markus7800/UPIX

UPIX: Universal Programmable Inference in JAX

... brings back stochastic control flow to probabilistic modelling in JAX.

Universal probabilistic programming languages (PPLs) like Pyro or Gen enable the user to specify models with stochastic support. This means that control flow and array shapes are allowed to depend on the values sampled during execution. This is fundamentally incompatible with JIT compilation in JAX, which traces a program with abstract values and therefore requires control flow and array shapes to be fixed at compile time. Thus, probabilistic programming systems built on top of JAX, like NumPyro, are restricted to models with static support, i.e. they disallow Python control flow that depends on sampled values. UPIX realises the Divide-Conquer-Combine (DCC) approach [1] as a framework that brings JIT compilation to universal PPLs and enables running inference on CPUs, GPUs or TPUs.

At its core, the DCC approach splits a model with stochastic support into a potentially infinite number of sub-models with static support. In UPIX, this is realised with a custom JAX interpreter that records and compiles the probabilistic program for each choice of branching decisions (all instances where an abstract JAX array tracer is made concrete). Thus, a program specified in our universal PPL is split into multiple JIT-compilable straight-line programs (SLPs).
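To make the divide step concrete, here is a small pure-Python sketch, not UPIX's interpreter. A hypothetical `BranchRecorder` plays the role of the JAX tracer and logs the outcome every time a traced comparison would be made concrete. The tuple of recorded decisions then acts as the SLP key, and every trace with the same key samples the same fixed set of addresses. The toy model and all names here are assumptions for illustration.

```python
import random


class BranchRecorder:
    """Hypothetical stand-in for a JAX tracer: records every branching
    decision at the point where a traced boolean would be made concrete."""

    def __init__(self) -> None:
        self.decisions: list = []

    def concretise(self, condition: bool) -> bool:
        self.decisions.append(bool(condition))
        return bool(condition)


def toy_program(rec: BranchRecorder, rng: random.Random) -> dict:
    """A toy model with stochastic support: which variables are sampled
    depends on the value drawn for "u"."""
    trace = {"u": rng.random()}
    if rec.concretise(trace["u"] < 0.5):
        trace["x"] = rng.gauss(0.0, 1.0)
    else:
        trace["y1"] = rng.gauss(0.0, 1.0)
        trace["y2"] = rng.gauss(0.0, 1.0)
    return trace


rng = random.Random(0)
slps: dict = {}  # branching decisions -> set of address tuples observed
for _ in range(1000):
    rec = BranchRecorder()
    trace = toy_program(rec, rng)
    slps.setdefault(tuple(rec.decisions), set()).add(tuple(sorted(trace)))

for decisions, addresses in slps.items():
    print(decisions, addresses)
```

Each key maps to exactly one set of addresses, which is what allows every SLP to be compiled as a program with static support.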

UPIX provides constructs for programmable inference: we enable the user to customise

  • how the model is split up in the divide step
  • how the inference is run in the conquer step
  • how the approximations of the sub-models are combined
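The three customisation points can be pictured as plain callables passed to a generic loop. The sketch below is a hypothetical skeleton, not UPIX's class hierarchy: `run_dcc`, `estimate_weight` and `weighted_mean` are illustrative names, and the toy SLPs return constant "posterior samples" so the combined result is easy to check.

```python
from typing import Callable, Dict, Hashable, List


def run_dcc(
    divide: Callable[[], List[Hashable]],
    conquer: Callable[[Hashable], List[float]],
    estimate_weight: Callable[[Hashable], float],
    combine: Callable[[Dict[Hashable, List[float]], Dict[Hashable, float]], float],
) -> float:
    """Hypothetical DCC skeleton: each argument is one customisation point."""
    slps = divide()                                  # divide: choose SLPs
    samples = {k: conquer(k) for k in slps}          # conquer: per-SLP inference
    weights = {k: estimate_weight(k) for k in slps}  # per-SLP marginal likelihood
    return combine(samples, weights)                 # combine: merge approximations


def weighted_mean(samples: Dict[Hashable, List[float]], weights: Dict[Hashable, float]) -> float:
    """Mean of the full model as a weight-normalised mixture of per-SLP means."""
    total = sum(weights.values())
    return sum(weights[k] / total * sum(xs) / len(xs) for k, xs in samples.items())


estimate = run_dcc(
    divide=lambda: [1, 2],
    conquer=lambda k: [float(k)] * 4,  # toy "posterior samples" of SLP k
    estimate_weight=lambda k: {1: 1.0, 2: 3.0}[k],
    combine=weighted_mean,
)
print(estimate)  # prints 1.75
```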

Usage

This is a work in progress. Instructions are coming soon.

Install options: [cpu], [cuda], and [tpu].

For now, we refer to the example programs in the evaluation folder.

We recommend using uv with uv run -p python3.13 --frozen --extra=cuda script_to_run.py.

Example

import jax
from upix.core import *

@model
def pedestrian():
    start = sample("start", dist.Uniform(0.,3.))
    position = start
    distance = 0.
    t = 0
    while (position > 0) & (distance < 10):
        t += 1
        step = sample(f"step_{t}", dist.Uniform(-1.,1.))
        position += step
        distance += jax.lax.abs(step)
    sample("obs", dist.Normal(distance, 0.1), observed=1.1)
    return start

Above, we have implemented the Pedestrian model from Mak et al. [2]. The syntax resembles NumPyro, with the critical difference that the while loop condition depends on position and distance, two quantities computed from the random variables f"step_{t}". Thus, the number of loop iterations depends on the values sampled during execution, resulting in stochastic support. Each straight-line program (SLP) corresponds to a sub-model in which the loop runs a fixed number of times.
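To see how prior samples fall into different SLPs, the following plain-Python sketch (no UPIX or JAX; the helper name `pedestrian_prior` is hypothetical) simulates the Pedestrian prior without the observation and groups traces by the number of loop iterations.

```python
import random
from collections import Counter


def pedestrian_prior(rng: random.Random) -> dict:
    """Plain-Python simulation of the Pedestrian prior (observation omitted).
    Returns the trace as a dict from address to sampled value."""
    trace = {"start": rng.uniform(0.0, 3.0)}
    position, distance, t = trace["start"], 0.0, 0
    while position > 0 and distance < 10:
        t += 1
        step = rng.uniform(-1.0, 1.0)
        trace[f"step_{t}"] = step
        position += step
        distance += abs(step)
    return trace


rng = random.Random(0)
# The SLP of a trace is determined by the number of loop iterations,
# i.e. by how many step_<t> addresses the trace contains.
slp_counts = Counter(len(pedestrian_prior(rng)) - 1 for _ in range(10_000))
for n_steps in sorted(slp_counts):
    print(f"{n_steps} steps: {slp_counts[n_steps]} prior samples")
```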

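import jax.numpy as jnp
from typing import List

class DCCConfig(MCMCDCC[T]):
    def get_MCMC_inference_regime(self, slp: SLP) -> MCMCRegime:
        # compose two MCMC steps: one kernel for "start", one for the "step_<t>" variables
        regime = MCMCSteps(
            # Metropolis-Hastings with an independent Uniform(0, 3) proposal for "start"
            MCMCStep(Variables("start"), RW(lambda x: dist.Uniform(jnp.zeros_like(x), 3))),
            # discontinuous HMC for all variables matching the pattern step_\d+
            MCMCStep(Variables(r"step_\d+"), DHMC(50, 0.05, 0.15)),
        )
        return regime
    def initialise_active_slps(self, active_slps: List[SLP], inactive_slps: List[SLP], rng_key: jax.Array):
        ...  # discover SLPs and choose which ones to run inference on
    def update_active_slps(self, active_slps: List[SLP], inactive_slps: List[SLP], rng_key: jax.Array):
        ...  # decide which SLPs stay active for the next DCC iteration

# `m` refers to the Pedestrian model above and `args` to parsed command-line
# arguments; see evaluation/pedestrian for complete, runnable scripts
dcc_obj = DCCConfig(m, verbose=2,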
    parallelisation=get_parallelisation_config(args),
    init_n_samples=250,
    init_estimate_weight_n_samples=2**20,
    mcmc_n_chains=8,
    mcmc_n_samples_per_chain=25_000,
    estimate_weight_n_samples=2**23,
    max_iterations=1,
)

result = dcc_obj.run(jax.random.key(0))

plot_histogram_by_slp(result, "start")


Above, we sketch an MCMC-DCC inference algorithm for the Pedestrian model. In get_MCMC_inference_regime, we customise the MCMC kernel used for each SLP. We use a Metropolis-Hastings kernel RW for the start variable, which simply proposes a uniform value from 0 to 3. For the step variables, we apply discontinuous HMC (DHMC), a variant of Hamiltonian Monte Carlo that can deal with discontinuities in the density.
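The RW kernel for start is a Metropolis-Hastings step whose proposal does not depend on the current state. A minimal stdlib sketch of such an independence step follows; the function name and the Normal(1.0, 0.3)-shaped target are hypothetical stand-ins for the SLP's conditional density, not UPIX code.

```python
import math
import random
from typing import Callable


def independent_uniform_mh_step(
    x: float,
    log_target: Callable[[float], float],
    rng: random.Random,
    low: float = 0.0,
    high: float = 3.0,
) -> float:
    """One Metropolis-Hastings step with a proposal that ignores the current
    state and draws uniformly from [low, high]. The proposal density is the
    same for every state, so it cancels in the acceptance ratio."""
    x_new = rng.uniform(low, high)
    log_alpha = log_target(x_new) - log_target(x)
    if rng.random() < math.exp(min(0.0, log_alpha)):
        return x_new  # accept the proposal
    return x  # reject and keep the current state


def log_target(x: float) -> float:
    # Hypothetical unnormalised target on [0, 3] with a Normal(1.0, 0.3) shape.
    if not 0.0 <= x <= 3.0:
        return -math.inf
    return -0.5 * ((x - 1.0) / 0.3) ** 2


rng = random.Random(0)
x, chain = 1.5, []
for _ in range(20_000):
    x = independent_uniform_mh_step(x, log_target, rng)
    chain.append(x)
print(sum(chain) / len(chain))  # close to 1.0
```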

In initialise_active_slps and update_active_slps, we specify how SLPs are found and for which of them inference should be run. In the former, we simply draw 250 samples from the program prior, which instantiates SLPs, weigh them with importance sampling, and make the most probable SLPs active. For this simple model, update_active_slps simply makes all SLPs inactive after running inference once. For more complex models, we may implement more sophisticated routines here that discard low-probability SLPs and slightly mutate high-probability SLPs, resulting in multiple DCC phases.
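The initialisation strategy can be sketched in plain Python as follows. This is a simplified, hypothetical version: it uses the same 250 prior samples both to discover SLPs and to estimate their weights (the configuration above uses a separate, much larger budget, init_estimate_weight_n_samples, for the weights), and `n_active` is an illustrative parameter. With the prior as proposal, the importance-sampling estimate of an SLP's weight is the average likelihood of the prior samples, counting only those that landed in that SLP.

```python
import math
import random
from collections import defaultdict

OBS, OBS_SD = 1.1, 0.1  # observation and noise scale from the Pedestrian model


def run_pedestrian(rng: random.Random):
    """Sample from the program prior; return (SLP key, log-likelihood of obs)."""
    position = rng.uniform(0.0, 3.0)
    distance, n_steps = 0.0, 0
    while position > 0 and distance < 10:
        n_steps += 1
        step = rng.uniform(-1.0, 1.0)
        position += step
        distance += abs(step)
    log_lik = -0.5 * ((OBS - distance) / OBS_SD) ** 2 - math.log(OBS_SD * math.sqrt(2 * math.pi))
    return n_steps, log_lik


def initialise_active_slps(n_samples: int = 250, n_active: int = 3, seed: int = 0):
    """Hypothetical sketch: discover SLPs with prior samples, estimate each SLP's
    weight by importance sampling with the prior as proposal, keep the top ones."""
    rng = random.Random(seed)
    lik_sums = defaultdict(float)
    for _ in range(n_samples):
        slp, log_lik = run_pedestrian(rng)
        lik_sums[slp] += math.exp(log_lik)
    # weight of SLP k ~ (1/N) * sum of likelihoods of prior samples that landed in k
    weights = {slp: s / n_samples for slp, s in lik_sums.items()}
    active = sorted(weights, key=weights.get, reverse=True)[:n_active]
    return active, weights


active, weights = initialise_active_slps()
print("active SLPs (number of steps):", active)
```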

Lastly, building on features of JAX, we allow the user to customise parallelisation and vectorisation for inference. In UPIX, you can run inference for multiple SLPs in parallel on multi-core CPUs or on different accelerator devices such as GPUs or TPUs. You can also use multiple devices to accelerate inference for a single SLP. This is especially useful for inference routines that parallelise efficiently, such as many-chain MCMC, multiple-run VI, or SMC; see the scaling section below.

The accompanying figure shows the inference result as an approximation to the posterior of the variable start. There is one histogram for each SLP, i.e. for each number of steps (loop iterations), and next to each histogram is the weight estimated for that SLP. Combining the samples according to these weights yields the histogram at the bottom, which approximates the posterior of start in the full model with stochastic support. This approximation is close to the ground truth.
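The combine step treats the full posterior as a mixture of the per-SLP approximations, each weighted by its normalised estimated weight. A minimal resampling sketch, with a hypothetical helper `combine_samples` and made-up toy samples and weights:

```python
import random
from typing import Dict, Hashable, List


def combine_samples(
    samples_by_slp: Dict[Hashable, List[float]],
    weights: Dict[Hashable, float],
    n: int,
    rng: random.Random,
) -> List[float]:
    """Resample from the mixture sum_k (w_k / sum_j w_j) * q_k, where q_k is
    the empirical distribution of the samples obtained for SLP k."""
    slps = list(samples_by_slp)
    # random.choices normalises the weights itself
    chosen = rng.choices(slps, weights=[weights[k] for k in slps], k=n)
    return [rng.choice(samples_by_slp[k]) for k in chosen]


samples_by_slp = {1: [0.0] * 10, 2: [1.0] * 10}  # toy per-SLP samples
weights = {1: 3.0, 2: 1.0}                       # toy estimated SLP weights
combined = combine_samples(samples_by_slp, weights, 40_000, random.Random(0))
print(sum(combined) / len(combined))  # close to 1 / (3 + 1) = 0.25
```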


Implemented DCC algorithms

Scaling Inference on multiple XLA devices

References

[1] Zhou, Yuan, et al. "Divide, conquer, and combine: a new inference strategy for probabilistic programs with stochastic support." International Conference on Machine Learning. PMLR, 2020.

[2] Mak, Carol, Fabian Zaiser, and Luke Ong. "Nonparametric Hamiltonian Monte Carlo." International Conference on Machine Learning. PMLR, 2021.

[3] Reichelt, Tim, Luke Ong, and Thomas Rainforth. "Rethinking variational inference for probabilistic programs with stochastic support." Advances in Neural Information Processing Systems 35 (2022): 15160-15175.

Reproducing Paper Results

Setup

Manual

To run the experiments on your machine, you need uv, Julia 1.9, and a C++ compiler.

Install uv, e.g. with curl -LsSf https://astral.sh/uv/install.sh | sh.

Install Julia 1.9, e.g. with curl -fsSL https://install.julialang.org | sh -s -- --yes --default-channel=1.9.

Docker

If you want to run the experiments in a docker container instead, build and run it with:

docker build . -t upix
docker run -it --name upix --rm upix

Make sure that all CPUs are available in the container.
To make GPUs available in the container, see https://docs.docker.com/engine/containers/resource_constraints/#gpu .
Runtimes in the Docker container may differ from runtimes when running locally.

Test the Setup

Run example on CPU

uv run --frozen -p python3.13 --extra=cpu evaluation/pedestrian/run_example.py sequential vmap_local

Run example on CUDA GPU (if available)

uv run --frozen -p python3.13 --extra=cuda evaluation/pedestrian/run_example.py sequential vmap_local

Run example on TPU (if available)

uv run --frozen -p python3.13 --extra=tpu evaluation/pedestrian/run_example.py sequential vmap_local

Section 4: Example DCC

Run the following commands from the root directory to reproduce the experiments from Section 4.

The experiments were run on an M2 Pro MacBook (without Docker).

Section 4.1: MCMC - Pedestrian Model

Run NP-DHMC baseline (with 8 parallel processes)

cd evaluation/pedestrian/nonparametric-hmc
uv run -p python3.10 --no-project --with-requirements=requirements.txt pedestrian.py NP-DHMC 8 1000 100 -n_processes 8  --store_samples
uv run -p python3.10 --no-project --with-requirements=requirements.txt check_results.py
cd ../../..

Run UPIX-MCMC-DCC (with 8 CPU devices)

uv run -p python3.13 --frozen --extra=cpu evaluation/pedestrian/run_comp.py sequential pmap --show_plots -host_device_count 8

Section 4.2: SDVI - Gaussian Process Model

Run SDVI baseline (original implementation by Reichelt et al. 2022, with 10 parallel processes)

bash evaluation/gp/sdvi/run_comp.sh 10

Run UPIX-SDVI (with 10 parallel processes)

uv run -p python3.13 --frozen --extra=cpu --with pandas evaluation/gp/run_comp_vi.py cpu_multiprocess vmap_local -num_workers 10

Section 4.3: RJMCMC - Gaussian Mixture Model

If you do not use the Docker image, install the Julia packages:

julia --project=evaluation/gmm/gen -e "import Pkg; Pkg.instantiate()"

Run RJMCMC Gen baseline (with 8 threads)

julia -t 8 --project=evaluation/gmm/gen evaluation/gmm/gen/gmm.jl 8 25000

Run UPIX-RJMCMC-DCC (with 8 CPU devices)

uv run -p python3.13 --frozen --extra=cpu evaluation/gmm/run_comp.py sequential pmap -host_device_count 8

Section 4.4: SMC - Gaussian Process Model

If you do not use the Docker image, install the Julia packages:

julia --project=evaluation/gp/autogp -e "import Pkg; Pkg.instantiate()"

Run AutoGP baseline (with 10 threads)

julia -t 10 --project=evaluation/gp/autogp evaluation/gp/autogp/main.jl 100 false

Run UPIX-SMC-DCC (with 10 CPU devices)

uv run -p python3.13 --frozen --extra=cpu --with=pandas evaluation/gp/run_comp_smc.py sequential smap_local -host_device_count 10 --show_plots

Section 4.5: VE - Urn Model

If you do not use the Docker image, build the Swift compiler, e.g. on Linux:

apt-get install -y g++ cmake libopenblas-dev liblapack-dev libarmadillo-dev
make compile -C evaluation/urn/milch/swift/

Compile and run the BLOG baseline

cd evaluation/urn/milch 
python3 compile.py
python3 run.py
cd ../../..

Run UPIX-VE-DCC

uv run -p python3.13 --frozen --extra=cpu evaluation/urn/run_comp.py sequential vmap_local 20 --jit_inf

Section 5: Scaling Experiments

We have implemented scripts that launch the run_scale.py scripts for each model with varying hardware and workload. Set $platform to cpu or cuda depending on your hardware, and set $ndevices to the number of devices to use; $ndevices has to be a power of 2. If the number of CPU cores on your machine is not a power of 2, you may prefix the following commands with taskset to restrict the available CPUs, e.g. taskset -c 0-7 python3 experiments/... to use 8 CPUs (Linux only). We ran our experiments on a Linux machine with 64 CPU cores and eight 48 GB NVIDIA GPUs (without Docker) using the following configurations:

($platform, $ndevices) =
    (cpu, 8) | (cpu, 16) | (cpu, 32) | (cpu, 64) |
    (cuda, 1) | (cuda, 2) | (cuda, 4) | (cuda, 8)

The script arguments following $platform and $ndevices set the workload range of each experiment as base-2 exponents.

For instance

python3 experiments/runners/run_pedestrian_scale.py cuda 1 0 20 sequential

runs the scaling experiment for the Pedestrian model with the number of MCMC chains varying from 2^0=1 to 2^20=1048576 on a single GPU.
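To sweep every configuration listed above, the command lines can be generated rather than typed by hand. The helper below is a hypothetical convenience that only builds commands of the documented form `python3 experiments/runners/<script> $platform $ndevices <lo> <hi> sequential`, optionally prefixed with taskset on Linux. It covers the four UPIX scaling scripts, not the reference baselines, whose argument lists differ.

```python
import shlex
from typing import List

# (platform, ndevices) configurations used for the scaling experiments
CONFIGS = [("cpu", 8), ("cpu", 16), ("cpu", 32), ("cpu", 64),
           ("cuda", 1), ("cuda", 2), ("cuda", 4), ("cuda", 8)]


def scale_command(platform: str, ndevices: int, lo: int, hi: int,
                  script: str = "experiments/runners/run_pedestrian_scale.py",
                  use_taskset: bool = False) -> List[str]:
    """Build (but do not run) the command line for one scaling run."""
    if ndevices <= 0 or (ndevices & (ndevices - 1)) != 0:
        raise ValueError("ndevices must be a power of 2")
    cmd = ["python3", script, platform, str(ndevices), str(lo), str(hi), "sequential"]
    if use_taskset and platform == "cpu":
        # Linux only: restrict the process to CPUs 0..ndevices-1
        cmd = ["taskset", "-c", f"0-{ndevices - 1}"] + cmd
    return cmd


for platform, ndevices in CONFIGS:
    print(shlex.join(scale_command(platform, ndevices, 0, 20, use_taskset=True)))
```

To launch a run instead of printing it, pass the list to subprocess.run(cmd, check=True).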

If you have less powerful hardware or do not want to run long experiments (they run up to 5 hours), you may lower the workloads.

For instance, with arguments cuda 1 0 19 sequential the experiment should take half the time, with cuda 1 0 18 sequential it should take a quarter of the time, and so on.
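The halving rule follows from the workload doubling at each step: 2^0 + ... + 2^hi = 2^(hi+1) - 1, so the largest workload accounts for about half of the total. A small sketch of this arithmetic, assuming runtime grows roughly linearly with workload (compilation and other fixed overheads are ignored):

```python
def workloads(lo: int, hi: int) -> list:
    """Workloads (e.g. number of MCMC chains) for exponent arguments lo and hi."""
    return [2 ** k for k in range(lo, hi + 1)]


def relative_cost(lo: int, hi: int, ref_hi: int) -> float:
    """Total workload of the range [lo, hi] relative to [lo, ref_hi], assuming
    runtime grows linearly with workload."""
    return sum(workloads(lo, hi)) / sum(workloads(lo, ref_hi))


print(workloads(0, 20)[-1])      # largest workload: 2**20 = 1048576
print(relative_cost(0, 19, 20))  # roughly one half
print(relative_cost(0, 18, 20))  # roughly one quarter
```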

Scaling MCMC - Pedestrian Model

python3 experiments/runners/run_pedestrian_scale.py $platform $ndevices 0 20 sequential

Scaling RJMCMC - Gaussian Mixture Model

python3 experiments/runners/run_gmm_scale.py $platform $ndevices 0 18 sequential

Scaling SDVI - Gaussian Process Model

python3 experiments/runners/run_gp_vi_scale.py $platform $ndevices 0 14 sequential

Scaling SMC - Gaussian Process Model

python3 experiments/runners/run_gp_smc_scale.py $platform $ndevices 0 15 sequential

MCMC Reference NP-DHMC - Pedestrian Model

python3 experiments/runners/run_npdhmc_scale.py $ndevices 0 20

RJMCMC Reference Gen RJMCMC - Gaussian Mixture Model

python3 experiments/runners/run_rjmcmc_scale.py $ndevices 0 15

SDVI Reference Original SDVI - Gaussian Process Model

python3 experiments/runners/run_sdvi_scale.py 1 0 3

SMC Reference AutoGP - Gaussian Process Model

python3 experiments/runners/run_autogp_scale.py $ndevices 0 15