26 changes: 26 additions & 0 deletions skyrl-train/skyrl_train/env_vars.py
@@ -0,0 +1,26 @@
import os


SKYRL_VLLM_DP_PORT_OFFSET = int(os.environ.get("SKYRL_VLLM_DP_PORT_OFFSET", 500))
"""
Offset for the data parallel port of the vLLM server.
"""
SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S = int(
os.environ.get("SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S", 600)
)
"""
Timeout, in seconds, for waiting until the inference server reports healthy.
"""
SKYRL_INCLUDE_PYTHONPATH_IN_RUNTIME_ENV = str(
os.environ.get("SKYRL_INCLUDE_PYTHONPATH_IN_RUNTIME_ENV", "False")
).lower() in (
"true",
"1",
"yes",
)
"""
Whether to include the PYTHONPATH environment variable in the Ray runtime
environment. When using a Ray nightly build, this is needed to avoid
dependency issues: PYTHONPATH should point to the local path where the
nightly is installed.
"""
32 changes: 31 additions & 1 deletion skyrl-train/skyrl_train/inference_engines/base.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Dict, TypedDict, Any, Optional, Hashable, TYPE_CHECKING
from typing import List, Dict, TypedDict, Any, Optional, Hashable, Protocol, runtime_checkable, TYPE_CHECKING

if TYPE_CHECKING:
from skyrl_train.weight_sync.transfer_strategy import WeightSyncInitInfo
@@ -31,6 +31,36 @@ class InferenceEngineOutput(TypedDict):
response_logprobs: Optional[List[List[float]]]


@runtime_checkable
class InferenceClientProtocol(Protocol):
"""
Structural protocol for inference clients.

Both InferenceEngineInterface and RemoteInferenceClient satisfy this protocol,
enabling code to work with either implementation via duck typing.

This is the minimal common interface for:
- Data plane: generate()
- Lifecycle: sleep(), wake_up(), teardown()
"""

async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
"""Generate completions for a batch of inputs."""
...

async def sleep(self, *args: Any, **kwargs: Any) -> Any:
"""Put the engine into sleep/low-power mode."""
...

async def wake_up(self, *args: Any, **kwargs: Any) -> Any:
"""Wake the engine from sleep mode."""
...

async def teardown(self) -> None:
"""Clean up resources."""
...


class InferenceEngineInterface(ABC):

@abstractmethod
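A brief duck-typing sketch against the new protocol (the `drain_and_sleep` helper is illustrative, not part of this diff). Note that the `runtime_checkable` isinstance check only verifies method presence, not signatures:

```python
from skyrl_train.inference_engines.base import (
    InferenceClientProtocol,
    InferenceEngineInput,
)


async def drain_and_sleep(client: InferenceClientProtocol, batch: InferenceEngineInput):
    """Illustrative: run one generation pass, then idle the engine."""
    assert isinstance(client, InferenceClientProtocol)  # structural check only
    output = await client.generate(batch)
    await client.sleep()
    return output
```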
70 changes: 4 additions & 66 deletions skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
@@ -6,7 +6,6 @@
from dataclasses import dataclass
from http import HTTPStatus
import ray
import torch
import asyncio
import vllm
from types import SimpleNamespace
@@ -24,7 +23,6 @@
)
from vllm.lora.request import LoRARequest
from uuid import uuid4
import warnings
from skyrl_train.inference_engines.base import (
InferenceEngineInterface,
InferenceEngineInput,
@@ -66,70 +64,10 @@ def setup_envvars_for_vllm(kwargs, bundle_indices):
logger.info(f"creating LLM with bundle_indices={bundle_indices}")


class WorkerWrap:
def test_rpc(self, *args, **kwargs):
"""Test RPC call to worker"""
return args, kwargs

def init_weight_update_communicator(self, init_info: bytes):
"""Init weight update communicator from init info.

Args:
init_info: Pickled bytes of WeightSyncInitInfo from the sender.
"""
import pickle

assert torch.distributed.is_initialized(), "default torch process group must be initialized"

# Unpickle init_info to restore the original object type
assert isinstance(init_info, bytes), f"Expected bytes, got {type(init_info).__name__}"
init_info = pickle.loads(init_info)

strategy_cls = init_info.strategy_type()

if hasattr(self, "_weight_receiver") and self._weight_receiver is not None:
# TODO(haochen): we should get rid of this flag and override existing receiver.
if init_info.override_existing_receiver:
self._weight_receiver.teardown()
self._weight_receiver = None
else:
warnings.warn(
"Detected an existing weight receiver. "
"For overriding, use `generator.override_existing_update_group=enable`"
)
return

self._weight_receiver = strategy_cls.create_receiver(init_info)

def load_weights(self, request: bytes) -> None:
"""Load weights using the receiver.

This method is called via collective_rpc from VLLMWeightLoader.

Args:
request: Pickled bytes of WeightUpdateRequest.
"""
import pickle

# Unpickle request to restore the original object type
assert isinstance(request, bytes), f"Expected bytes, got {type(request).__name__}"
request = pickle.loads(request)

weight_list = []
for name, tensor in self._weight_receiver.receive_weights(request):
weight_list.append((name, tensor))

self.model_runner.model.load_weights(weights=weight_list)

for weight in weight_list:
del weight

    # TODO (sumanthrh): Add destroy process group RPC as an atexit handler to Trainer code.
def teardown_weight_receiver(self):
if not hasattr(self, "_weight_receiver") or self._weight_receiver is None:
warnings.warn("No weight receiver to teardown")
return
self._weight_receiver.teardown()
# Backward compatibility: WorkerWrap has moved to inference_servers.vllm_worker
# This alias preserves the old import path for existing scripts/configs.
# TODO (Kourosh): Remove this alias once all references are updated.
from skyrl_train.inference_servers.vllm_worker import WorkerWrap # noqa: F401, E402


class BaseVLLMInferenceEngine(InferenceEngineInterface):
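As a quick check of the compatibility shim above, both import paths should resolve to the same class. This is implied by the alias rather than tested in this diff:

```python
# Old path, kept working by the alias, and the new canonical path.
from skyrl_train.inference_engines.vllm.vllm_engine import WorkerWrap as OldWorkerWrap
from skyrl_train.inference_servers.vllm_worker import WorkerWrap

assert OldWorkerWrap is WorkerWrap  # the alias re-exports the same class
```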
74 changes: 74 additions & 0 deletions skyrl-train/skyrl_train/inference_servers/common.py
@@ -0,0 +1,74 @@
"""
Common utilities for inference servers.

Uses Ray's public network utilities for consistency with Ray's cluster management.
"""

import logging
import socket
from dataclasses import dataclass

import ray

logger = logging.getLogger(__name__)


@dataclass
class ServerInfo:
"""Information about a running inference server."""

ip: str
port: int

@property
def url(self) -> str:
return f"http://{self.ip}:{self.port}"


def get_node_ip() -> str:
"""
Get the IP address of the current node.

    Returns the node IP address as reported by Ray's public network utility,
    ``ray.util.get_node_ip_address()``.
"""
return ray.util.get_node_ip_address()


def get_open_port(start_port: int | None = None) -> int:
"""
Get an available port.

Args:
start_port: If provided, search for an available port starting from this value.
If None, let the OS assign a free port.

Returns:
An available port number.
"""
if start_port is not None:
# Search for available port starting from start_port
port = start_port
while True:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(("", port))
return port
except OSError:
port += 1
if port > 65535:
raise RuntimeError(f"No available port found starting from {start_port}")

# Let OS assign a free port
# Try IPv4 first
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
except OSError:
pass

# Try IPv6
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
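
A small usage sketch of these helpers (the base port and printed URL are illustrative):

```python
from skyrl_train.inference_servers.common import ServerInfo, get_node_ip, get_open_port

# Let the OS assign a free port, or search upward from a preferred base port.
os_assigned = get_open_port()
searched = get_open_port(start_port=8000)

info = ServerInfo(ip=get_node_ip(), port=searched)
print(info.url)  # e.g. http://10.0.0.5:8000
```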
105 changes: 105 additions & 0 deletions skyrl-train/skyrl_train/inference_servers/protocols.py
@@ -0,0 +1,105 @@
"""
Protocols for inference server components.

These define the interfaces that server implementations must follow.
"""

from argparse import Namespace
from typing import Optional, Protocol, Tuple, runtime_checkable

from skyrl_train.inference_servers.common import ServerInfo


@runtime_checkable
class ServerActorProtocol(Protocol):
"""
Protocol defining the interface for server actor classes.

Any server actor class (vLLM, SGLang, etc.) must implement this interface
to be usable with ServerGroup.

Example:
class MyServerActor(ServerActorProtocol):
@staticmethod
def compute_num_gpus_per_server(cli_args: Namespace) -> int:
return cli_args.tensor_parallel_size

def __init__(self, cli_args, start_port, server_idx, ...):
...

async def start(self) -> ServerInfo:
...
"""

@staticmethod
def compute_num_gpus_per_server(cli_args: Namespace) -> int:
"""
Compute the number of GPUs needed per server instance.

This is called before actor creation to determine placement group size.

Args:
cli_args: Engine-specific CLI arguments.

Returns:
Number of GPUs required per server (e.g., TP * PP for vLLM).
"""
...

def __init__(
self,
cli_args: Namespace,
start_port: int,
server_idx: int,
start_bundle_idx: int,
dp_size: int,
dp_master_address: Optional[str],
dp_rpc_port: Optional[int],
enable_pd: bool,
nixl_side_channel_base: int,
) -> None:
"""
Initialize the server actor.

Args:
cli_args: Engine-specific CLI arguments.
start_port: Base port to search for available port.
server_idx: Index of this server in the group (0-indexed).
start_bundle_idx: Starting bundle index in placement group for this server's workers.
dp_size: Data parallel size (-1 to disable DP).
dp_master_address: DP master address (for non-rank-0 servers).
dp_rpc_port: DP RPC port (for non-rank-0 servers).
enable_pd: Enable prefill-decode disaggregation.
nixl_side_channel_base: Base port for NIXL side channels.
"""
...

def get_server_info(self) -> ServerInfo:
"""Get the server's IP and port info."""
...

def get_dp_info(self) -> Tuple[str, int]:
"""
Get the DP master address and RPC port.

Only called on server_idx=0 when DP is enabled.

Returns:
Tuple of (master_address, rpc_port).
"""
...

async def start(self) -> ServerInfo:
"""
Start the server.

This should block until the server is healthy and ready to serve requests.

Returns:
ServerInfo with the server's IP and port.
"""
...

async def shutdown(self) -> None:
"""Gracefully shutdown the server."""
...
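
Because the protocol is structural and `runtime_checkable`, a conforming server actor needs no inheritance. Below is a minimal skeleton under the assumption of a vLLM-style `tensor_parallel_size` CLI argument; all non-protocol names and constructor values are illustrative:

```python
from argparse import Namespace
from typing import Optional, Tuple

from skyrl_train.inference_servers.common import ServerInfo
from skyrl_train.inference_servers.protocols import ServerActorProtocol


class DummyServerActor:
    """Illustrative skeleton; satisfies ServerActorProtocol structurally."""

    @staticmethod
    def compute_num_gpus_per_server(cli_args: Namespace) -> int:
        return cli_args.tensor_parallel_size

    def __init__(
        self,
        cli_args: Namespace,
        start_port: int,
        server_idx: int,
        start_bundle_idx: int,
        dp_size: int,
        dp_master_address: Optional[str],
        dp_rpc_port: Optional[int],
        enable_pd: bool,
        nixl_side_channel_base: int,
    ) -> None:
        self._info = ServerInfo(ip="127.0.0.1", port=start_port)

    def get_server_info(self) -> ServerInfo:
        return self._info

    def get_dp_info(self) -> Tuple[str, int]:
        return (self._info.ip, 0)  # real impl: DP master address and RPC port

    async def start(self) -> ServerInfo:
        return self._info  # real impl: launch the server, block until healthy

    async def shutdown(self) -> None:
        pass


actor = DummyServerActor(
    cli_args=Namespace(tensor_parallel_size=1),
    start_port=8000,
    server_idx=0,
    start_bundle_idx=0,
    dp_size=-1,
    dp_master_address=None,
    dp_rpc_port=None,
    enable_pd=False,
    nixl_side_channel_base=5600,
)
assert isinstance(actor, ServerActorProtocol)  # structural check only
```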