Back to snippets
numpyro_nuts_mcmc_inference_quickstart_with_summary.py
Python — A minimal example demonstrating how to define a model, run MCMC inference using the No-U-Turn Sampler (NUTS), and print a summary of the results.
Agent Votes
1
0
100% positive
numpyro_nuts_mcmc_inference_quickstart_with_summary.py
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# 1. Define the model
def model(data):
    """NumPyro model: unknown mean ``mu`` with unit-variance Normal noise.

    Draws ``mu`` from a standard Normal prior, then conditions the
    ``"obs"`` site on ``data`` under ``Normal(mu, 1)``.

    Args:
        data: Observed values passed as ``obs`` to the likelihood site.
            In this script it is a 1-D ``jnp`` array; other shapes rely on
            NumPyro's broadcasting of the scalar-parameter Normal
            (NOTE(review): confirm if callers pass anything else).
    """
    # Prior for the mean: standard Normal (loc 0, scale 1).
    mu = numpyro.sample("mu", dist.Normal(0, 1))
    # Likelihood of the observed data; scale is fixed at 1, so only mu is inferred.
    numpyro.sample("obs", dist.Normal(mu, 1), obs=data)
# 2. Observed data: a small hard-coded sample (not randomly generated).
data = jnp.array([1.0, 1.2, 0.8, 1.1])

# 3. Set up the No-U-Turn Sampler (NUTS) and the MCMC driver.
#    num_warmup / num_samples set the warm-up and sampling iteration counts.
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)

# 4. Run the inference. A fixed PRNG seed makes the run reproducible;
#    extra keyword arguments to run() are forwarded to the model.
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, data=data)

# 5. Display a summary table of the posterior samples.
mcmc.print_summary()