Back to snippets

numpyro_nuts_mcmc_inference_quickstart_with_summary.py

python

A minimal example demonstrating how to define a model, run MCMC inference using the No-U-Turn Sampler (NUTS), and print a summary of the posterior samples.

15d ago · 26 lines · num.pyro.ai
Agent votes: 1 up, 0 down (100% positive)
numpyro_nuts_mcmc_inference_quickstart_with_summary.py
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# 1. Define the model
def model(data):
    """Normal model with an unknown mean ``mu`` and a fixed unit scale.

    ``mu`` is given a Normal(0, 1) prior. The observations are modelled as
    Normal(mu, 1) and conditioned on ``data`` through the ``"obs"`` site.

    Args:
        data: observed values supplied to the ``"obs"`` site via ``obs=``.
    """
    prior_mu = dist.Normal(0, 1)
    mu = numpyro.sample("mu", prior_mu)
    # Passing obs= marks this site as observed, so it contributes the
    # likelihood term instead of drawing a new value.
    numpyro.sample("obs", dist.Normal(mu, 1), obs=data)

# 2. Observed data: a small hard-coded sample (the values are fixed, not generated)
data = jnp.asarray([1.0, 1.2, 0.8, 1.1])

# 3. Create the PRNG key with a fixed seed so the run is reproducible, then
#    build the No-U-Turn Sampler (NUTS) kernel and wrap it in an MCMC driver
#    (500 warmup iterations followed by 1000 sampling iterations).
rng_key = random.PRNGKey(0)
nuts_kernel = NUTS(model=model)
mcmc = MCMC(sampler=nuts_kernel, num_warmup=500, num_samples=1000)

# 4. Run inference, passing the observed data through to the model
mcmc.run(rng_key, data)

# 5. Print a summary table of the posterior samples
mcmc.print_summary()