Comparison with MCMC

In this notebook, we compare two parameter estimation methods: simulation-based inference (SBI) and traditional MCMC.

The model

Our model of choice is a Gaussian mixture, which we describe below:

The parameters \(\boldsymbol{\theta} \in \mathbb{R}^n\) are sampled independently from a uniform distribution,

\[ \theta_i \sim \mathcal{U}([-1, 1]),\]

for \(i \in \{1, \ldots, n\}\).

The data \(\boldsymbol{x} \in \mathbb{R}^n\) are generated as follows:

\[ \boldsymbol{x} \sim 0.5 \mathcal{N}(\mu=\boldsymbol{\theta}, \sigma_1 \boldsymbol{I}_n) + 0.5 \mathcal{N}(\mu=\boldsymbol{\theta}, \sigma_2 \boldsymbol{I}_n),\]

where \(\sigma_1 \gg \sigma_2 > 0\). We fix \(\sigma_1 = 1, \sigma_2 = 0.01\).
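
The generative process above can be sketched in a few lines. This is a dependency-free illustration, not the GaussianMixtureBenchmark implementation: the function names sample_prior and simulate are hypothetical, and \(\sigma_1, \sigma_2\) are treated as standard deviations, which is an assumption about the notation.

```python
import random

SIGMA_1 = 1.0   # wide component (value from the text)
SIGMA_2 = 0.01  # narrow component (value from the text)


def sample_prior(ndim, rng):
    """Draw theta with independent U([-1, 1]) components."""
    return [rng.uniform(-1.0, 1.0) for _ in range(ndim)]


def simulate(theta, rng):
    """Draw x from the equal-weight mixture centred on theta.

    Assumption: sigma_1 and sigma_2 are standard deviations.
    """
    # Pick one mixture component for the whole vector, each with probability 0.5
    sigma = SIGMA_1 if rng.random() < 0.5 else SIGMA_2
    return [rng.gauss(t, sigma) for t in theta]


rng = random.Random(1991)
theta = sample_prior(2, rng)
x = simulate(theta, rng)
```

Roughly half of the simulated points land within a few hundredths of \(\boldsymbol{\theta}\), while the other half spread out on the scale of \(\sigma_1\).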

import sys


import torch
import matplotlib.pyplot as plt
import arviz as az
from sbi import analysis
from sbi import utils

from sbisandbox.benchmarks import GaussianMixtureBenchmark
from sbisandbox.runners import SNPERunner, MCMCRunner

ndim = 2
benchmark = GaussianMixtureBenchmark(ndim)
num_simulations = 1000
seed = 1991
theta, x = benchmark.get_observations(
    num_simulations, seed=seed, simulation_batch_size=None
)

Plotting kernel density estimates of the parameters and the data:

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

theta_labels = [r"$\theta_1$", r"$\theta_2$"]
x_labels = ["$x_1$", "$x_2$"]

az.plot_kde(theta[:, 0], theta[:, 1], contour=False, ax=ax1)
ax1.set(xlabel=theta_labels[0], ylabel=theta_labels[1])

az.plot_kde(x[:, 0], x[:, 1], contour=False, ax=ax2)
ax2.set(xlabel=x_labels[0], ylabel=x_labels[1])
Choosing fiducial values:

fiducial_theta = torch.zeros(ndim)
num_samples = 10000
num_simulations = 10000
x0 = benchmark.simulator(fiducial_theta.unsqueeze(0))

Reference posterior samples

We can sample from the analytical posterior in a straightforward manner:

  1. Sample \(u \sim \text{Bernoulli}(p=0.5)\)
  2. If \(u=1\), choose \(\sigma=\sigma_1\), else choose \(\sigma=\sigma_2\)
  3. Sample \(\boldsymbol{\theta} \sim \mathcal{N}(\boldsymbol{x}, \sigma I_n)\)
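
The sampling steps above can be sketched as follows, read as drawing \(\boldsymbol{\theta}\) around a fixed observation. This is a standard-library illustration rather than the get_posterior_samples implementation: sample_reference_posterior is a hypothetical name, \(\sigma\) is assumed to be a standard deviation, and rejecting draws outside the prior box \([-1, 1]^n\) is an added assumption that follows from the uniform prior.

```python
import random


def sample_reference_posterior(x_obs, num_samples, rng, sigma_1=1.0, sigma_2=0.01):
    """Draw theta from the equal-weight mixture centred on x_obs, kept inside [-1, 1]^n."""
    samples = []
    while len(samples) < num_samples:
        # Steps 1-2: a fair coin picks the scale
        sigma = sigma_1 if rng.random() < 0.5 else sigma_2
        # Step 3: Gaussian draw around the observation
        theta = [rng.gauss(xi, sigma) for xi in x_obs]
        # Assumption: the uniform prior restricts theta to [-1, 1]^n, so a
        # draw outside the box is discarded and the whole procedure restarts
        if all(-1.0 <= t <= 1.0 for t in theta):
            samples.append(theta)
    return samples


rng = random.Random(1991)
reference = sample_reference_posterior([0.0, 0.0], 1000, rng)
```

Because a rejection restarts from the coin flip, the wide component, which loses more mass to the box, ends up with fewer than half of the accepted samples.
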
true_samples = benchmark.get_posterior_samples((1, num_samples), x=x0)
def get_data(labels, samples, warmup: int = 0, thin: int = 1):
    data = (samples[warmup::thin, :, 0], samples[warmup::thin, :, 1])
    return dict(zip(labels, data))

data_from_true_posterior = get_data(theta_labels, true_samples)

az.plot_pair(data_from_true_posterior, var_names=theta_labels, marginals=True)
We indeed see two populations of samples with different variance. Let us see the data generated from the posterior samples as a sanity check:

x_from_posterior_samples = benchmark.simulator(true_samples.squeeze())
az.plot_kde(
    x_from_posterior_samples[:, 0], x_from_posterior_samples[:, 1], contour=False
)
We run MCMC with Pyro's implementation of the NUTS sampler, which the MCMCRunner class calls under the hood:


# Warmup will take num_samples / 2 steps as well
mcmc = MCMCRunner(benchmark=benchmark, seed=seed)
mcmc_samples = mcmc.sample(num_samples // 2, x=x0)
Visualizing the results in a corner plot:

data_from_mcmc = get_data(theta_labels, mcmc_samples["theta"])
az.plot_pair(data_from_mcmc, var_names=theta_labels, marginals=True)
We observe that the chains do not seem to explore the whole posterior mass along the first dimension \(\theta_1\). Perhaps we need more samples?
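
One rough way to quantify this, sketched below, is to measure how many draws fall outside the narrow component along a given dimension and compare that share between the chain and the reference samples. The helper wide_fraction and the threshold of five narrow standard deviations are illustrative assumptions, and the inputs are toy lists rather than the notebook's tensors.

```python
def wide_fraction(values, center, sigma_narrow=0.01, k=5.0):
    """Share of 1-D draws farther than k * sigma_narrow from center."""
    far = sum(1 for v in values if abs(v - center) > k * sigma_narrow)
    return far / len(values)


# Toy draws along one dimension, observation at 0.0 (made-up numbers)
stuck_chain = [0.001, -0.004, 0.007, 0.002, -0.009, 0.003]
mixing_chain = [0.001, 0.42, -0.004, -0.77, 0.007, 0.15]

print(wide_fraction(stuck_chain, center=0.0))   # 0.0
print(wide_fraction(mixing_chain, center=0.0))  # 0.5
```

A chain whose share is far below that of the reference samples is spending too little time in the wide component.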


We choose NPE as the SBI counterpart so as to avoid running a complementary MCMC step. We use neural spline flows as they seem to have the best performance on this task.

npe = SNPERunner(benchmark=benchmark, seed=seed, density_estimator="nsf")
num_simulations = 10000
training_kwargs = {
    "show_train_summary": True,
    "training_batch_size": 128,
    "use_combined_loss": True,
    "learning_rate": 2e-4,
}
truncate_at = 1e-3

    num_simulations, x_0=x0, training_kwargs=training_kwargs, truncate_at=truncate_at

samples_from_npe = npe.sample(num_samples, x=x0)
data_from_npe = get_data(theta_labels, samples_from_npe.unsqueeze(0))
az.plot_pair(data_from_npe, var_names=theta_labels, marginals=True)
We do see some posterior leakage, but it does not seem to be significant with respect to the total number of samples.
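
The amount of leakage can be measured directly as the share of samples that fall outside the prior support \([-1, 1]^n\). The sketch below uses plain Python lists and a hypothetical helper leakage_fraction; with the notebook's tensors the same check would be applied to samples_from_npe.

```python
def leakage_fraction(samples, low=-1.0, high=1.0):
    """Share of samples with at least one coordinate outside [low, high]."""
    outside = sum(
        1 for theta in samples if any(t < low or t > high for t in theta)
    )
    return outside / len(samples)


# Toy 2-D samples (made-up numbers): one of the four leaks past the upper bound
toy_samples = [[0.1, -0.2], [0.9, 1.3], [-0.5, 0.5], [0.0, 0.0]]
print(leakage_fraction(toy_samples))  # 0.25
```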

Comparing results

Visualizing marginals:

az.plot_density(
    [data_from_npe, data_from_mcmc],
    data_labels=["NPE", "MCMC"],
)
az.plot_density(
    [data_from_npe, data_from_true_posterior],
    data_labels=["NPE", "Reference"],
)
az.plot_density(
    [data_from_mcmc, data_from_true_posterior],
    data_labels=["MCMC", "Reference"],
)
