🚀 The feature, motivation and pitch
The current Store implementation enforces a strict separation between Scheduler and Worker roles in the inference engine. This design imposes several critical limitations:
- Inconvenient lifecycle management: block allocation (Alloc) and commit operations (Commit) must be invoked in the Scheduler process, while data transfer is restricted to the Worker process. This forces the Scheduler to pre-allocate blocks before dispatching tasks to Workers and then collect execution results for post-processing, making error handling and control flow unnecessarily complex.
- Limited internal optimization: the rigid role separation prevents efficient implementation of data caching strategies and I/O aggregation mechanisms.
- Capability gaps: emerging scenarios such as LayerWise and ChunkBlock operations cannot be supported effectively.
- Hindered composability: the hierarchical cache design becomes difficult to extend. Some Store instances should focus on DRAM-SSD transfers while others handle HBM-DRAM transfers, but the current interface does not support this specialization cleanly.
To address these limitations, we propose a new Store architecture with redesigned interfaces to better support data caching, I/O aggregation, and hierarchical composition.
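As an illustration of the hierarchical composition called out above, the sketch below stacks two specialized stores behind one facade that consults the fast tier before falling back to the slow one. The class name TieredStore and its wiring are hypothetical assumptions for illustration only; the concrete interfaces are proposed under Alternatives.

from typing import List


class TieredStore:
    """Hypothetical facade over two specialized stores (names illustrative)."""

    def __init__(self, fast, slow) -> None:
        self.fast = fast  # e.g., a Store specialized for HBM-DRAM transfers
        self.slow = slow  # e.g., a Store specialized for DRAM-SSD transfers

    def lookup(self, block_ids: List[bytes]) -> List[bool]:
        # A block is available if either tier holds it.
        hot = self.fast.lookup(block_ids)
        cold = self.slow.lookup(block_ids)
        return [h or c for h, c in zip(hot, cold)]

    def prefetch(self, block_ids: List[bytes]) -> None:
        # Fire-and-forget promotion hint: ask the slow tier to stage the
        # blocks that the fast tier does not yet hold.
        hot = self.fast.lookup(block_ids)
        self.slow.prefetch([b for b, h in zip(block_ids, hot) if not h])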
Alternatives
C++ Base Classes
namespace UC {
/**
 * @brief Abstract interface for a key-value store that supports
 * asynchronous load/dump of cached blocks.
 *
 * Thread safety: All public methods must be thread-safe. Concurrent calls
 * are allowed; implementations are responsible for internal synchronization.
 */
class Store {
public:
    virtual ~Store() = default;

    /**
     * @brief Check whether the given blocks exist in storage.
     *
     * @param blocks Array of block identifiers to test.
     * @param num Number of block identifiers to test.
     * @return Expected<std::vector<uint8_t>>
     *         - On success: a vector whose i-th element is **true** if
     *           blocks[i] is present, otherwise **false**.
     *         - On failure: appropriate Status code.
     */
    virtual Expected<std::vector<uint8_t>> Lookup(const Detail::BlockId* blocks, size_t num) = 0;

    /**
     * @brief Hint the store to prefetch given blocks into high-speed cache.
     *
     * This call is **non-blocking** and **fire-and-forget**; it returns
     * immediately and carries no completion guarantee. Implementations may
     * ignore the hint if prefetching is not supported or resources are
     * unavailable.
     *
     * @param blocks Array of block identifiers to be prefetched.
     * @param num Number of block identifiers to be prefetched.
     *
     * @note Thread-safe; may be called concurrently with other operations.
     * @note Implementations that do not support prefetching may treat this
     *       call as a no-op.
     */
    virtual void Prefetch(const Detail::BlockId* blocks, size_t num) = 0;

    /**
     * @brief Start an asynchronous load (storage → device) transfer.
     *
     * @param task Description of shards to be loaded.
     * @return Expected<TaskHandle>
     *         - On success: a task handle that can be passed to Wait() or Check().
     *         - On failure: relevant Status code.
     */
    virtual Expected<Detail::TaskHandle> Load(Detail::TaskDesc task) = 0;

    /**
     * @brief Start an asynchronous dump (device → storage) transfer.
     *
     * @param task Description of shards to be stored.
     * @return Expected<TaskHandle>
     *         - On success: a task handle that can be passed to Wait() or Check().
     *         - On failure: relevant Status code.
     */
    virtual Expected<Detail::TaskHandle> Dump(Detail::TaskDesc task) = 0;

    /**
     * @brief Poll for task completion without blocking.
     *
     * @param taskId Task handle returned by Load() or Dump().
     * @return Expected<bool>
     *         - **true** if the task has finished (successfully or with an error).
     *         - **false** if the task is still running.
     *         - Any other value indicates an error in the poll itself.
     */
    virtual Expected<bool> Check(Detail::TaskHandle taskId) = 0;

    /**
     * @brief Block until the specified task completes.
     *
     * @param taskId Task handle returned by Load() or Dump().
     * @return Status::OK on successful completion, otherwise an error code
     *         describing the failure.
     */
    virtual Status Wait(Detail::TaskHandle taskId) = 0;

protected:
    /**
     * @brief Protected default constructor.
     *
     * Prevents direct instantiation and enforces derivation.
     */
    Store() = default;
};
} // namespace UC

Python Base Classes
from abc import ABC, abstractmethod
from typing import Dict, List

import torch


class Task(ABC):
    """Asynchronous task handle returned by transfer operations.

    This is an opaque token that can be polled or awaited.
    """

    pass


class UcmKVStoreBase(ABC):
    """Abstract base class for KV-cache-centric storage backends.

    A concrete storage vendor must implement this interface to participate in
    the unified-cache-management (UCM) system.
    """

    def __init__(self, config: Dict[str, object]) -> None:
        """Initialize the store with vendor-specific configuration.

        Args:
            config: Key-value mapping containing vendor-specific parameters
                (e.g., connection string, cache size, compression level).
        """
        self.config = config

    @abstractmethod
    def cc_store(self) -> int:
        """Return a low-level C/C++ pointer to the underlying store.

        Returns:
            An opaque ``int`` representing the ``Store*`` instance that can
            be passed to native code.
        """
        pass

    @abstractmethod
    def lookup(self, block_ids: List[bytes]) -> List[bool]:
        """Check presence of blocks in external storage.

        Args:
            block_ids: List of vLLM block hashes (raw bytes).

        Returns:
            A list of booleans, ``True`` if the corresponding block exists in
            storage, ``False`` otherwise. The order matches ``block_ids``.
        """
        pass

    @abstractmethod
    def prefetch(self, block_ids: List[bytes]) -> None:
        """Asynchronously prefetch blocks into high-speed cache.

        Args:
            block_ids: List of vLLM block hashes to prefetch.
        """
        pass

    @abstractmethod
    def load(
        self,
        block_ids: List[bytes],
        shard_index: List[int],
        dst_tensor: List[List[torch.Tensor]],
    ) -> Task:
        """Initiate transfer of KV cache from storage to device.

        Args:
            block_ids: Hashes of the blocks to load.
            shard_index: Shard index for each block.
            dst_tensor: Double-list structure where ``dst_tensor[i][j]`` is the
                destination PyTorch tensor on device for block ``i``, tensor ``j``.

        Returns:
            A ``Task`` handle that can be used to check or wait for completion.
        """
        pass

    @abstractmethod
    def dump(
        self,
        block_ids: List[bytes],
        shard_index: List[int],
        src_tensor: List[List[torch.Tensor]],
    ) -> Task:
        """Initiate transfer of KV cache from device to storage.

        Args:
            block_ids: Hashes of the blocks to write.
            shard_index: Shard index for each block.
            src_tensor: Double-list structure where ``src_tensor[i][j]`` is the
                source PyTorch tensor on device for block ``i``, tensor ``j``.

        Returns:
            A ``Task`` handle that can be used to check or wait for completion.
        """
        pass

    @abstractmethod
    def fetch_data(
        self,
        block_ids: List[bytes],
        shard_index: List[int],
        dst_addr: List[List[int]],
    ) -> Task:
        """Low-level fetch: copy KV data to device pointers.

        Args:
            block_ids: Block hashes to load.
            shard_index: Shard index for each block.
            dst_addr: Double-list of ``int`` pointers (as Python ``int``) to
                pre-allocated device buffers.

        Returns:
            A ``Task`` handle for the asynchronous copy.
        """
        pass

    @abstractmethod
    def dump_data(
        self,
        block_ids: List[bytes],
        shard_index: List[int],
        src_addr: List[List[int]],
    ) -> Task:
        """Low-level dump: copy KV data from device pointers.

        Args:
            block_ids: Block hashes to store.
            shard_index: Shard index for each block.
            src_addr: Double-list of ``int`` pointers to device buffers.

        Returns:
            A ``Task`` handle for the asynchronous copy.
        """
        pass

    @abstractmethod
    def wait(self, task: Task) -> None:
        """Block until the given transfer task completes.

        Args:
            task: Task handle returned by ``load``, ``dump``, ``fetch_data``,
                or ``dump_data``.
        """
        pass

    @abstractmethod
    def check(self, task: Task) -> bool:
        """Non-blocking poll for task completion.

        Args:
            task: Task handle returned by any transfer method.

        Returns:
            ``True`` if the task has finished, ``False`` if still in-flight.
        """
        pass
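To make the contract concrete, here is a minimal, hypothetical in-process backend: it keeps shards in a Python dict and completes every transfer synchronously, so its Task handles are already finished when returned. The class names (InMemoryStore, _DoneTask) and the per-pointer sizing convention are illustrative assumptions, not part of the proposal.

import ctypes


class _DoneTask(Task):
    """Handle for a transfer that already completed synchronously."""


class InMemoryStore(UcmKVStoreBase):
    """Toy dict-backed backend illustrating the interface; not a real vendor store."""

    def __init__(self, config: Dict[str, object]) -> None:
        super().__init__(config)
        self._blocks: Dict[bytes, Dict[int, bytes]] = {}  # block_id -> shard -> payload
        self._shard_size = int(config["shard_size"])      # assumed config key

    def cc_store(self) -> int:
        return 0  # no native Store* behind this toy backend

    def lookup(self, block_ids: List[bytes]) -> List[bool]:
        return [bid in self._blocks for bid in block_ids]

    def prefetch(self, block_ids: List[bytes]) -> None:
        pass  # the hint is legitimately ignorable (see Prefetch() above)

    def load(self, block_ids, shard_index, dst_tensor) -> Task:
        # Delegate to the pointer path via each (CPU) tensor's raw address.
        return self.fetch_data(
            block_ids, shard_index,
            [[t.data_ptr() for t in ts] for ts in dst_tensor])

    def dump(self, block_ids, shard_index, src_tensor) -> Task:
        return self.dump_data(
            block_ids, shard_index,
            [[t.data_ptr() for t in ts] for ts in src_tensor])

    def fetch_data(self, block_ids, shard_index, dst_addr) -> Task:
        for bid, shard, ptrs in zip(block_ids, shard_index, dst_addr):
            data = self._blocks[bid][shard]
            step = len(data) // len(ptrs)  # assume an equal split across pointers
            for j, ptr in enumerate(ptrs):
                ctypes.memmove(ptr, data[j * step:(j + 1) * step], step)
        return _DoneTask()

    def dump_data(self, block_ids, shard_index, src_addr) -> Task:
        for bid, shard, ptrs in zip(block_ids, shard_index, src_addr):
            step = self._shard_size // len(ptrs)
            payload = b"".join(ctypes.string_at(ptr, step) for ptr in ptrs)
            self._blocks.setdefault(bid, {})[shard] = payload
        return _DoneTask()

    def wait(self, task: Task) -> None:
        return None  # every transfer completed before its handle was returned

    def check(self, task: Task) -> bool:
        return True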
Usage Examples

import os
import secrets

import cupy

# UcmPosixStore is the POSIX-file-backed Store implementation; in a real
# setup, import it from the UCM package.
os.environ["UC_LOGGER_LEVEL"] = "debug"

block_size = 1048576  # 1 MiB per block
config = {
    "backends": ["."],
    "io_size": block_size,
    "shard_size": block_size,
    "block_size": block_size,
    "transfer_io_direct": True,    # use direct I/O for transfers
    "transfer_stream_number": 16,  # concurrent transfer streams
}
store = UcmPosixStore(config)

# Freshly generated random block ids must not be present yet.
block_num = 1024
block_ids = [secrets.token_bytes(16) for _ in range(block_num)]
founds = store.lookup(block_ids)
assert not any(founds)

# Dump one shard per block from pinned host buffers, then confirm presence.
shard_idxes = [0 for _ in range(block_num)]
data1 = [[cupy.cuda.alloc_pinned_memory(block_size).ptr] for _ in range(block_num)]
handle = store.dump_data(block_ids, shard_idxes, data1)
store.wait(handle)
founds = store.lookup(block_ids)
assert all(founds)

# Fetch the same blocks back into a second set of pinned buffers.
data2 = [[cupy.cuda.alloc_pinned_memory(block_size).ptr] for _ in range(block_num)]
handle = store.fetch_data(block_ids, shard_idxes, data2)
store.wait(handle)
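Where blocking in wait() is undesirable, e.g. inside a scheduler loop, the same transfer can be driven with the non-blocking check() poll instead. A minimal sketch, reusing the store and buffers from the example above:

import time

# Poll the task instead of blocking in wait(); useful when the caller needs
# to overlap scheduling work with the in-flight transfer.
handle = store.fetch_data(block_ids, shard_idxes, data2)
while not store.check(handle):
    time.sleep(0.001)  # task still running; do other work here instead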
Additional context
No response