UPIX brings back stochastic control flow to probabilistic modelling in JAX.
Universal probabilistic programming languages (PPLs) like Pyro or Gen enable the user to specify models with stochastic support: control flow and array shapes are allowed to depend on the values sampled during execution. This is fundamentally incompatible with JIT compilation in JAX. Probabilistic programming systems built on top of JAX, like NumPyro, are therefore restricted to models with static support, i.e. they disallow Python control flow. UPIX realises the Divide-Conquer-Combine (DCC) approach [1] as a framework which brings back JIT compilation for universal PPLs and enables running inference on CPUs, GPUs, or TPUs.
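For example, the following plain JAX snippet (not UPIX code) fails under JIT compilation because a Python branch inspects a traced value:

```python
import jax

@jax.jit
def f(x):
    if x > 0:  # Python branch on an abstract tracer:
        return x  # under jit this raises an error (TracerBoolConversionError in recent JAX)
    return -x

f(1.0)  # fails under jit; without @jax.jit this would run fine
```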
At its core, the DCC approach splits up a model with stochastic support into a potentially infinite number of sub-models with static support. In UPIX this is realised with a custom JAX interpreter which records and compiles the probabilistic program for each choice of branching decisions (all instances where an abstract JAX array tracer is made concrete). Thus, a program specified in our universal PPL is split up into multiple JIT-compilable straight-line programs (SLPs).
UPIX provides constructs for programmable inference: we enable the user to customise
- how the model is split up in the divide step,
- how inference is run in the conquer step, and
- how the approximations of the sub-models are combined (a schematic sketch of the whole loop follows below).
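The sketch below is a self-contained, hypothetical toy; its names and its analytic conquer step are illustrative and unrelated to the actual UPIX API. It runs all three steps on a tiny model whose branch index k has countably infinite support:

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats

# Toy model with stochastic support (hypothetical, not the UPIX API):
#   k ~ Geometric(0.5)        selects the branch / SLP (k = 0, 1, 2, ...)
#   theta | k ~ Normal(k, 1)  sub-model with static support
#   observe y ~ Normal(theta, 0.5) with y = 2.5
y = 2.5

def slp_log_weight(k):
    # combine weight: log p(k) + log p(y | k); analytic for this toy model
    log_p_k = k * jnp.log(0.5) + jnp.log(0.5)
    log_p_y = stats.norm.logpdf(y, loc=k, scale=jnp.sqrt(1.0 + 0.5**2))
    return log_p_k + log_p_y

def slp_posterior_sample(k, key, n):
    # conquer: within SLP k, the posterior of theta is Gaussian in closed form
    var = 1.0 / (1.0 + 1.0 / 0.5**2)
    mean = var * (k + y / 0.5**2)
    return mean + jnp.sqrt(var) * jax.random.normal(key, (n,))

# divide: truncate the infinite family of SLPs at a bound
ks = jnp.arange(8)
weights = jax.nn.softmax(jax.vmap(slp_log_weight)(ks))
keys = jax.random.split(jax.random.key(0), len(ks))
# combine: mixture of per-SLP posteriors under the normalised weights
samples = [slp_posterior_sample(k, kk, 1_000) for k, kk in zip(ks, keys)]
```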
This is a work in progress. Instructions are coming soon.
Install options: [cpu], [cuda], and [tpu].
For now, we refer to the example programs in the evaluation folder.
We recommend using uv with uv run -p python3.13 --frozen --extra=cuda script_to_run.py.
```python
import jax
from upix.core import *

@model
def pedestrian():
    start = sample("start", dist.Uniform(0., 3.))
    position = start
    distance = 0.
    t = 0
    while (position > 0) & (distance < 10):
        t += 1
        step = sample(f"step_{t}", dist.Uniform(-1., 1.))
        position += step
        distance += jax.lax.abs(step)
    sample("obs", dist.Normal(distance, 0.1), observed=1.1)
    return start
```

Above, we have implemented the Pedestrian model from Mak et al. [2].
The syntax resembles NumPyro, with the critical difference that the while loop depends on position and distance, two quantities that are computed from the random variables f"step_{t}".
Thus, the number of while loop iterations depends on the values sampled during execution, resulting in stochastic support.
Each straight-line program (SLP) corresponds to a sub-model in which the loop runs for a fixed number of iterations.
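For illustration, the SLP in which the loop body executes exactly twice is equivalent to the following hand-written straight-line model. UPIX constructs these programs automatically via its interpreter; this snippet (reusing the imports from above) only shows what an SLP represents:

```python
@model
def pedestrian_slp_2():
    # SLP for exactly two loop iterations: control flow and shapes are static
    start = sample("start", dist.Uniform(0., 3.))
    step_1 = sample("step_1", dist.Uniform(-1., 1.))
    step_2 = sample("step_2", dist.Uniform(-1., 1.))
    distance = jax.lax.abs(step_1) + jax.lax.abs(step_2)
    # this SLP covers the region of the sample space where the loop exits
    # after two steps: start + step_1 > 0 and start + step_1 + step_2 <= 0
    sample("obs", dist.Normal(distance, 0.1), observed=1.1)
    return start
```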
```python
class DCCConfig(MCMCDCC[T]):
    def get_MCMC_inference_regime(self, slp: SLP) -> MCMCRegime:
        regime = MCMCSteps(
            MCMCStep(Variables("start"), RW(lambda x: dist.Uniform(jnp.zeros_like(x), 3))),
            MCMCStep(Variables(r"step_\d+"), DHMC(50, 0.05, 0.15)),
        )
        return regime

    def initialise_active_slps(self, active_slps: List[SLP], inactive_slps: List[SLP], rng_key: jax.Array):
        ...

    def update_active_slps(self, active_slps: List[SLP], inactive_slps: List[SLP], rng_key: PRNGKey):
        ...

dcc_obj = DCCConfig(m, verbose=2,
    parallelisation=get_parallelisation_config(args),
    init_n_samples=250,
    init_estimate_weight_n_samples=2**20,
    mcmc_n_chains=8,
    mcmc_n_samples_per_chain=25_000,
    estimate_weight_n_samples=2**23,
    max_iterations=1,
)
result = dcc_obj.run(jax.random.key(0))
plot_histogram_by_slp(result, "start")
```

Above, we sketch an MCMC-DCC inference algorithm for the Pedestrian model.
In get_MCMC_inference_regime, we customise the MCMC kernel used for each SLP.
We use a Metropolis-Hastings kernel RW for the start variable, which simply proposes a uniform value between 0 and 3.
For the step variables, we apply discontinuous HMC (DHMC), a variant of Hamiltonian Monte Carlo which can deal with discontinuities.
In initialise_active_slps and update_active_slps, we may specify how we find SLPs and for which of them inference should be run.
In the former, we simply draw 250 samples from the program prior, which instantiates SLPs; we then weigh the SLPs with importance sampling and make the most probable ones active.
For this simple model, in update_active_slps we simply make all SLPs inactive after running inference once.
For more complex models, we may implement more sophisticated routines here that discard low-probability SLPs and slightly mutate high-probability SLPs, resulting in multiple DCC phases; a toy selection criterion is sketched below.
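To make this concrete, here is a self-contained toy criterion in plain Python (hypothetical and independent of the actual UPIX API): keep only SLPs whose estimated log-weight is close to the best one.

```python
def select_active(log_weights: dict, threshold: float = 10.0):
    # Hypothetical criterion: an SLP stays active if its estimated
    # log-weight is within `threshold` nats of the best SLP.
    best = max(log_weights.values())
    active = [slp for slp, lw in log_weights.items() if lw >= best - threshold]
    inactive = [slp for slp in log_weights if slp not in active]
    return active, inactive

# toy usage: SLP names mapped to estimated log-weights
active, inactive = select_active({"slp_1": -3.2, "slp_2": -5.0, "slp_3": -40.1})
```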
Lastly, building on features of JAX, we allow the user to customise parallelisation and vectorisation for inference. In UPIX, you can run inference for multiple SLPs in parallel on multi-core CPUs or on different accelerator devices like GPUs or TPUs. We also provide the option to use multiple devices to accelerate inference for a single SLP. This is especially useful for inference routines which can be efficiently parallelised, like many-chain MCMC, multiple-run VI, or SMC; see the scaling section below.
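For intuition on why many-chain MCMC vectorises so well, here is a self-contained JAX sketch (a toy random-walk Metropolis sampler on a standard normal target, not UPIX code) that runs 1024 chains with a single jax.vmap:

```python
import jax
import jax.numpy as jnp

def log_prob(x):
    return -0.5 * x**2  # toy target: standard normal, up to a constant

def rw_step(x, key):
    # one random-walk Metropolis step
    key_prop, key_acc = jax.random.split(key)
    prop = x + 0.5 * jax.random.normal(key_prop)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_prob(prop) - log_prob(x)
    return jnp.where(accept, prop, x)

@jax.jit
def run_chain(key, n_steps=1_000):
    def body(x, k):
        x = rw_step(x, k)
        return x, x
    _, samples = jax.lax.scan(body, jnp.array(0.0), jax.random.split(key, n_steps))
    return samples

# a single vmap vectorises all chains on one device; across devices,
# jax.pmap or shard_map could be used instead (UPIX exposes such
# choices via its parallelisation config)
chain_keys = jax.random.split(jax.random.key(0), 1024)
samples = jax.vmap(run_chain)(chain_keys)  # shape (1024, 1000)
```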
On the right, you can see the inference result as an approximation to the posterior of the variable start.
We have histograms for each SLP, i.e. for each number of loop iterations.
On the right of the histograms you can see the weight that was estimated for each SLP.
Combining the samples according to those weights results in the histogram on the bottom, which approximates the posterior of start in the full model with stochastic support.
We can see that this approximation is close to the ground truth.
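Conceptually, the combine step forms a mixture of the per-SLP posterior approximations, weighted by the estimated SLP weights. A minimal toy sketch (not the UPIX implementation) resamples the pooled per-SLP samples accordingly:

```python
import jax
import jax.numpy as jnp

def combine(rng_key, slp_samples, log_weights, n=10_000):
    # Toy combine step: resample pooled per-SLP samples such that each SLP
    # contributes in proportion to its normalised weight.
    pooled = jnp.concatenate(slp_samples)
    probs = jax.nn.softmax(jnp.asarray(log_weights))
    per_sample = jnp.concatenate([
        jnp.full(s.shape[0], w / s.shape[0]) for s, w in zip(slp_samples, probs)
    ])
    idx = jax.random.choice(rng_key, pooled.shape[0], shape=(n,), p=per_sample)
    return pooled[idx]

# toy usage: two SLPs with fake posterior samples and estimated log-weights
posterior = combine(jax.random.key(0),
                    [jnp.array([0.5, 0.6, 0.7]), jnp.array([1.5, 1.6])],
                    log_weights=[-1.0, -2.3])
```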
- Markov chain Monte Carlo (MCMC) DCC. See the Pedestrian example.
- Variational Inference DCC (SDVI [3]). See the Gaussian process example.
- Reversible Jump / Involutive MCMC DCC. See the Gaussian mixture model example.
- Sequential Monte Carlo DCC. See the Gaussian process example.
- Variable Elimination DCC. See the Urn example.
[1] Zhou, Yuan, et al. "Divide, conquer, and combine: a new inference strategy for probabilistic programs with stochastic support." International Conference on Machine Learning. PMLR, 2020.
[2] Mak, Carol, Fabian Zaiser, and Luke Ong. "Nonparametric Hamiltonian Monte Carlo." International Conference on Machine Learning. PMLR, 2021.
[3] Reichelt, Tim, Luke Ong, and Thomas Rainforth. "Rethinking variational inference for probabilistic programs with stochastic support." Advances in Neural Information Processing Systems 35 (2022): 15160-15175.
To run the experiments on your machine you need uv, Julia 1.9, and a C++ compiler.
Install uv, e.g. with curl -LsSf https://astral.sh/uv/install.sh | sh.
Install Julia 1.9, e.g. with curl -fsSL https://install.julialang.org | sh -s -- --yes --default-channel=1.9.
If you want to run the experiments in a docker container instead, build and run it with:
docker build . -t upix
docker run -it --name upix --rm upix
Make sure to make all CPUs available in the container.
To make GPUs in the container available, see https://docs.docker.com/engine/containers/resource_constraints/#gpu .
Runtimes using the Docker container may differ from running locally.
Run example on CPU
uv run --frozen -p python3.13 --extra=cpu evaluation/pedestrian/run_example.py sequential vmap_local
Run example on CUDA GPU (if available)
uv run --frozen -p python3.13 --extra=cuda evaluation/pedestrian/run_example.py sequential vmap_local
Run example on TPU (if available)
uv run --frozen -p python3.13 --extra=tpu evaluation/pedestrian/run_example.py sequential vmap_local
Run the following commands from the root directory to reproduce the experiments from Section 4.
Experiments were run on an M2 Pro MacBook (without Docker).
Run NP-DHMC baseline (with 8 parallel processes)
cd evaluation/pedestrian/nonparametric-hmc
uv run -p python3.10 --no-project --with-requirements=requirements.txt pedestrian.py NP-DHMC 8 1000 100 -n_processes 8 --store_samples
uv run -p python3.10 --no-project --with-requirements=requirements.txt check_results.py
cd ../../..
Run UPIX-MCMC-DCC (with 8 CPU devices)
uv run -p python3.13 --frozen --extra=cpu evaluation/pedestrian/run_comp.py sequential pmap --show_plots -host_device_count 8
Run SDVI baseline (original implementation by Reichelt et al. 2022, with 10 parallel processes)
bash evaluation/gp/sdvi/run_comp.sh 10
Run UPIX-SDVI (with 10 parallel processes)
uv run -p python3.13 --frozen --extra=cpu --with pandas evaluation/gp/run_comp_vi.py cpu_multiprocess vmap_local -num_workers 10
If you do not use the docker image, install the julia packages
julia --project=evaluation/gmm/gen -e "import Pkg; Pkg.instantiate()"
Run RJMCMC Gen baseline (with 8 threads)
julia -t 8 --project=evaluation/gmm/gen evaluation/gmm/gen/gmm.jl 8 25000
Run UPIX-RJMCMC-DCC (with 8 CPU devices)
uv run -p python3.13 --frozen --extra=cpu evaluation/gmm/run_comp.py sequential pmap -host_device_count 8
If you do not use the docker image, install the julia packages
julia --project=evaluation/gp/autogp -e "import Pkg; Pkg.instantiate()"
Run AutoGP baseline (with 10 threads)
julia -t 10 --project=evaluation/gp/autogp evaluation/gp/autogp/main.jl 100 false
Run UPIX-SMC-DCC (with 10 CPU devices)
uv run -p python3.13 --frozen --extra=cpu --with=pandas evaluation/gp/run_comp_smc.py sequential smap_local -host_device_count 10 --show_plots
If you do not use the docker image, compile the Swift compiler. E.g. on Linux
apt-get install -y g++ cmake libopenblas-dev liblapack-dev libarmadillo-dev
make compile -C evaluation/urn/milch/swift/
Compile and run the BLOG baseline
cd evaluation/urn/milch
python3 compile.py
python3 run.py
cd ../../..
Run UPIX-VE-DCC
uv run -p python3.13 --frozen --extra=cpu evaluation/urn/run_comp.py sequential vmap_local 20 --jit_inf
We have implemented scripts to launch the run_scale.py scripts for each model with varying hardware and workload.
Set the $platform (cpu | cuda) and $ndevices arguments depending on your hardware.
$ndevices has to be a power of 2.
If your CPU core count is not a power of 2, you may prefix the following commands with taskset to restrict the available CPUs, e.g. taskset -c 0-7 python3 experiments/... to use 8 CPUs (only works on Linux).
We ran our experiments on a Linux machine with 64 CPU cores and eight NVIDIA GPUs with 48 GB of memory each (without Docker), using the following configurations:
($platform, $ndevices) =
(cpu, 8) | (cpu, 16) | (cpu, 32) | (cpu, 64) |
(cuda, 1) | (cuda, 2) | (cuda, 4) | (cuda, 8)
The script arguments following $platform and $ndevices set the workload range for each experiment as base-2 exponents.
For instance
python3 experiments/runners/run_pedestrian_scale.py cuda 1 0 20 sequential
runs the scaling experiment for the Pedestrian model with number of MCMC chains varying from 2^0=1 to 2^20=1048576 on a single GPU.
If you have less powerful hardware or do not want to run long experiments (they can take up to 5 hours), you may lower the workloads.
For instance, with arguments cuda 1 0 19 sequential the experiment should take half the time, with cuda 1 0 18 sequential it should take a quarter of the time, and so on.
python3 experiments/runners/run_pedestrian_scale.py $platform $ndevices 0 20 sequential
python3 experiments/runners/run_gmm_scale.py $platform $ndevices 0 18 sequential
python3 experiments/runners/run_gp_vi_scale.py $platform $ndevices 0 14 sequential
python3 experiments/runners/run_gp_smc_scale.py $platform $ndevices 0 15 sequential
python3 experiments/runners/run_npdhmc_scale.py $ndevices 0 20
python3 experiments/runners/run_rjmcmc_scale.py $ndevices 0 15
python3 experiments/runners/run_sdvi_scale.py 1 0 3
python3 experiments/runners/run_autogp_scale.py $ndevices 0 15

