
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/covtype.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_covtype.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_covtype.py:


Example: MCMC Methods for Tall Data
===================================

This example illustrates the use of several MCMC methods that are suitable for tall data (datasets with a very large number of observations):

    - `algo="SA"` uses the sample adaptive MCMC method in [1]
    - `algo="HMCECS"` uses the energy conserving subsampling method in [2]
    - `algo="FlowHMCECS"` uses a normalizing flow to transform the posterior
      geometry into a Gaussian-like one, following the neural transport approach
      in [3], and then draws the posterior samples with HMCECS. Currently, this
      method gives the best mixing rate among these methods.

**References:**

    1. *Sample Adaptive MCMC*,
       Michael Zhu (2019)
    2. *Hamiltonian Monte Carlo with energy conserving subsampling*,
       Dang, K. D., Quiroz, M., Kohn, R., Tran, M.-N., & Villani, M. (2019)
    3. *NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport*,
       Hoffman, M. et al. (2019)

.. GENERATED FROM PYTHON SOURCE LINES 26-238

.. code-block:: Python


    import argparse
    import time

    import matplotlib.pyplot as plt

    from jax import random
    import jax.numpy as jnp

    import numpyro
    import numpyro.distributions as dist
    from numpyro.examples.datasets import COVTYPE, load_dataset
    from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SA, SVI, Trace_ELBO, init_to_value
    from numpyro.infer.autoguide import AutoBNAFNormal
    from numpyro.infer.reparam import NeuTraReparam


    def _load_dataset():
        """Load the covtype dataset as a binary logistic-regression problem.

        Every feature column is standardized (zero mean, unit standard deviation)
        and a trailing column of ones is appended so that the last regression
        coefficient acts as the intercept. Labels are turned into booleans by
        comparing them against ``argmax`` of the per-class counts.

        :return: a ``(features, labels)`` pair, where ``labels`` is a boolean array
            aligned with the rows of ``features``.
        """
        _, fetch = load_dataset(COVTYPE, shuffle=False)
        features, labels = fetch()

        # normalize features and add intercept
        features = (features - features.mean(0)) / features.std(0)
        features = jnp.hstack([features, jnp.ones((features.shape[0], 1))])

        # make binary feature
        # NOTE(review): ``jnp.argmax(counts)`` is a *position* in the sorted array of
        # unique labels, not the label value itself, so unless labels start at 0
        # this does not select the most frequent class. It is kept unchanged
        # because the MAP reference point hard-coded in ``benchmark_hmc`` (taken
        # from the linked edward2 example) presumably assumes this exact
        # binarization — confirm before "fixing" it.
        _, counts = jnp.unique(labels, return_counts=True)
        specific_category = jnp.argmax(counts)
        labels = labels == specific_category

        num_data = features.shape[0]
        num_positive = labels.sum()
        print("Data shape:", features.shape)
        print(
            "Label distribution: {} has label 1, {} has label 0".format(
                num_positive, num_data - num_positive
            )
        )
        return features, labels


    def model(data, labels, subsample_size=None):
        """Bayesian logistic regression with independent standard normal priors.

        :param data: design matrix; its second dimension is the number of
            regression coefficients and its first indexes observations.
        :param labels: observed binary outcomes, one per row of ``data``.
        :param subsample_size: when not ``None``, the ``"N"`` plate draws a random
            subset of this many rows at each evaluation instead of using them all.
        :return: the value of the ``"obs"`` site for the rows selected by the plate.
        """
        num_coefs = data.shape[1]
        prior = dist.Normal(jnp.zeros(num_coefs), jnp.ones(num_coefs))
        coefs = numpyro.sample("coefs", prior)

        num_rows = data.shape[0]
        with numpyro.plate("N", num_rows, subsample_size=subsample_size) as batch:
            # ``batch`` holds the row indices chosen by the plate (all rows when
            # no subsampling is requested)
            batch_logits = jnp.dot(data[batch], coefs)
            likelihood = dist.Bernoulli(logits=batch_logits)
            return numpyro.sample("obs", likelihood, obs=labels[batch])


    def benchmark_hmc(args, features, labels):
        """Sample the logistic regression ``model`` with the algorithm in ``args.algo``.

        Supported values of ``args.algo`` are ``"HMC"``, ``"NUTS"``, ``"HMCECS"``,
        ``"SA"`` and ``"FlowHMCECS"``. After sampling, the mean acceptance
        probability, a posterior summary and the elapsed wall-clock time
        (including SVI training for ``"FlowHMCECS"``) are printed.

        :param args: parsed command-line options. Reads ``algo``, ``num_warmup``,
            ``num_samples``, ``num_steps`` (HMC only), ``dense_mass`` (HMC, NUTS and
            HMCECS) and ``num_chains`` (treated as 1 when absent or ``None``).
        :param features: design matrix passed to ``model``; its second dimension
            is the number of regression coefficients.
        :param labels: binary observations passed to ``model``.
        :raises ValueError: if ``args.algo`` is not one of the supported names.
        """
        rng_key = random.key(1)
        start = time.time()
        # a MAP estimate at the following source
        # https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
        ref_params = {
            "coefs": jnp.array(
                [
                    +2.03420663e00,
                    -3.53567265e-02,
                    -1.49223924e-01,
                    -3.07049364e-01,
                    -1.00028366e-01,
                    -1.46827862e-01,
                    -1.64167881e-01,
                    -4.20344204e-01,
                    +9.47479829e-02,
                    -1.12681836e-02,
                    +2.64442056e-01,
                    -1.22087866e-01,
                    -6.00568838e-02,
                    -3.79419506e-01,
                    -1.06668741e-01,
                    -2.97053963e-01,
                    -2.05253899e-01,
                    -4.69537191e-02,
                    -2.78072730e-02,
                    -1.43250525e-01,
                    -6.77954629e-02,
                    -4.34899796e-03,
                    +5.90927452e-02,
                    +7.23133609e-02,
                    +1.38526391e-02,
                    -1.24497898e-01,
                    -1.50733739e-02,
                    -2.68872194e-02,
                    -1.80925727e-02,
                    +3.47936489e-02,
                    +4.03552800e-02,
                    -9.98773426e-03,
                    +6.20188080e-02,
                    +1.15002751e-01,
                    +1.32145107e-01,
                    +2.69109547e-01,
                    +2.45785132e-01,
                    +1.19035013e-01,
                    -2.59744357e-02,
                    +9.94279515e-04,
                    +3.39266285e-02,
                    -1.44057125e-02,
                    -6.95222765e-02,
                    -7.52013028e-02,
                    +1.21171586e-01,
                    +2.29205526e-02,
                    +1.47308692e-01,
                    -8.34354162e-02,
                    -9.34122875e-02,
                    -2.97472421e-02,
                    -3.03937674e-01,
                    -1.70958012e-01,
                    -1.59496680e-01,
                    -1.88516974e-01,
                    -1.20889175e00,
                ]
            )
        }
        if args.algo == "HMC":
            # fixed (non-adapted) step size shrinking with the number of data points
            step_size = jnp.sqrt(0.5 / features.shape[0])
            trajectory_length = step_size * args.num_steps
            kernel = HMC(
                model,
                step_size=step_size,
                trajectory_length=trajectory_length,
                adapt_step_size=False,
                dense_mass=args.dense_mass,
            )
            subsample_size = None
        elif args.algo == "NUTS":
            kernel = NUTS(model, dense_mass=args.dense_mass)
            subsample_size = None
        elif args.algo == "HMCECS":
            subsample_size = 1000
            inner_kernel = NUTS(
                model,
                init_strategy=init_to_value(values=ref_params),
                dense_mass=args.dense_mass,
            )
            # note: with subsample_size=1000 and num_blocks=100, only 10 subsample
            # indices are refreshed at each MCMC step, so roughly N / 10 MCMC steps
            # are needed to cycle through the whole dataset of N rows
            kernel = HMCECS(
                inner_kernel, num_blocks=100, proxy=HMCECS.taylor_proxy(ref_params)
            )
        elif args.algo == "SA":
            # NB: this kernel requires large num_warmup and num_samples
            # and running on GPU is much faster than on CPU
            kernel = SA(
                model, adapt_state_size=1000, init_strategy=init_to_value(values=ref_params)
            )
            subsample_size = None
        elif args.algo == "FlowHMCECS":
            subsample_size = 1000
            guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
            svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
            svi_result = svi.run(random.key(2), 2000, features, labels)
            params, losses = svi_result.params, svi_result.losses
            plt.plot(losses)
            plt.show()

            neutra = NeuTraReparam(guide, params)
            neutra_model = neutra.reparam(model)
            # ``coefs`` is the model's only latent site and has one entry per
            # feature column, so the flattened shared latent has the same size;
            # start (and anchor the Taylor proxy) at its origin
            neutra_ref_params = {"auto_shared_latent": jnp.zeros(features.shape[1])}
            # no need to adapt mass matrix if the flow does a good job
            inner_kernel = NUTS(
                neutra_model,
                init_strategy=init_to_value(values=neutra_ref_params),
                adapt_mass_matrix=False,
            )
            kernel = HMCECS(
                inner_kernel, num_blocks=100, proxy=HMCECS.taylor_proxy(neutra_ref_params)
            )
        else:
            raise ValueError(
                "Invalid algorithm {!r}, expected one of 'HMC', 'NUTS', 'HMCECS', "
                "'SA' or 'FlowHMCECS'.".format(args.algo)
            )
        # honour --num-chains (previously parsed but never forwarded to MCMC);
        # fall back to a single chain when the option is missing or None
        num_chains = getattr(args, "num_chains", None) or 1
        mcmc = MCMC(
            kernel,
            num_warmup=args.num_warmup,
            num_samples=args.num_samples,
            num_chains=num_chains,
        )
        mcmc.run(rng_key, features, labels, subsample_size, extra_fields=("accept_prob",))
        print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
        mcmc.print_summary(exclude_deterministic=False)
        print("\nMCMC elapsed time:", time.time() - start)


    def main(args):
        """Load the binarized covtype data and run the sampler chosen by ``args.algo``."""
        dataset = _load_dataset()
        # ``dataset`` is the ``(features, labels)`` pair expected by the benchmark
        benchmark_hmc(args, *dataset)


    if __name__ == "__main__":
        assert numpyro.__version__.startswith("0.21.0")
        parser = argparse.ArgumentParser(description="parse args")
        parser.add_argument(
            "-n", "--num-samples", default=1000, type=int, help="number of samples"
        )
        parser.add_argument(
            "--num-warmup", default=1000, type=int, help="number of warmup steps"
        )
        parser.add_argument(
            "--num-steps", default=10, type=int, help='number of steps (for "HMC")'
        )
        # ``const=1`` makes a bare ``--num-chains`` mean one chain instead of None
        parser.add_argument(
            "--num-chains",
            nargs="?",
            default=1,
            const=1,
            type=int,
            help="number of MCMC chains",
        )
        parser.add_argument(
            "--algo",
            default="HMCECS",
            type=str,
            # reject unknown algorithms at parse time, before the dataset is loaded
            choices=["HMC", "NUTS", "HMCECS", "SA", "FlowHMCECS"],
            help='whether to run "HMC", "NUTS", "HMCECS", "SA" or "FlowHMCECS"',
        )
        parser.add_argument("--dense-mass", action="store_true")
        parser.add_argument("--x64", action="store_true")
        parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
        args = parser.parse_args()

        # configure the platform, device count and precision before running main
        numpyro.set_platform(args.device)
        numpyro.set_host_device_count(args.num_chains)
        if args.x64:
            numpyro.enable_x64()

        main(args)


.. _sphx_glr_download_examples_covtype.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: covtype.ipynb <covtype.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: covtype.py <covtype.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: covtype.zip <covtype.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
