
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/prodlda.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_prodlda.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_prodlda.py:


Example: ProdLDA with Flax
==========================

In this example, we follow [1] to implement the ProdLDA topic model from
*Autoencoding Variational Inference For Topic Models* by Akash Srivastava and Charles
Sutton [2]. This model produces consistently better topics than vanilla LDA and trains
much more quickly. Furthermore, it does not require a custom inference algorithm that
relies on complex mathematical derivations. This example also serves as an
introduction to Flax modules in NumPyro.

Note that unlike [1, 2], this implementation uses a Dirichlet prior directly rather
than approximating it with a softmax-normal distribution.

For the interested reader, a natural extension of this model is CombinedTM [3], which
uses a pre-trained sentence transformer (such as https://www.sbert.net/) to produce
contextualized document embeddings that give the encoder a richer input than the
bag-of-words representation alone.

**References:**
    1. http://pyro.ai/examples/prodlda.html
    2. Akash Srivastava and Charles Sutton (2017), "Autoencoding Variational Inference
       For Topic Models"
    3. Federico Bianchi, Silvia Terragni, and Dirk Hovy (2021), "Pre-training is a Hot
       Topic: Contextualized Document Embeddings Improve Topic Coherence"
       (https://arxiv.org/abs/2004.03974)

.. image:: ../_static/img/examples/prodlda.png
    :align: center

.. GENERATED FROM PYTHON SOURCE LINES 33-256

.. code-block:: Python


    import argparse

    import matplotlib.pyplot as plt
    import pandas as pd
    from sklearn.datasets import fetch_20newsgroups
    from sklearn.feature_extraction.text import CountVectorizer
    from wordcloud import WordCloud

    import flax.linen as nn
    import jax
    from jax import device_put, random
    import jax.numpy as jnp

    import numpyro
    from numpyro.contrib.module import flax_module
    import numpyro.distributions as dist
    from numpyro.infer import SVI, TraceMeanField_ELBO


    class FlaxEncoder(nn.Module):
        """Inference network mapping a document to Dirichlet concentrations.

        ``guide`` uses the output as the concentration of the variational
        Dirichlet over each document's topic proportions ``theta``.
        """

        vocab_size: int  # kept for symmetry with the guide; not referenced in __call__
        num_topics: int  # length of the output concentration vector
        hidden: int  # width of the two hidden Dense layers
        dropout_rate: float  # dropout probability applied after the hidden layers

        @nn.compact
        def __call__(self, inputs, is_training):
            """Return strictly positive concentrations of shape ``(..., num_topics)``.

            :param inputs: presumably bag-of-words count vectors (``guide`` passes
                rows of the document-term matrix) -- TODO confirm for other callers.
            :param bool is_training: when True, dropout is active and BatchNorm
                normalises with batch statistics (updating its running averages);
                when False, dropout is disabled and BatchNorm uses the running
                averages.
            """
            h = nn.softplus(nn.Dense(self.hidden)(inputs))
            h = nn.softplus(nn.Dense(self.hidden)(h))
            h = nn.Dropout(self.dropout_rate, deterministic=not is_training)(h)
            h = nn.Dense(self.num_topics)(h)

            # BatchNorm without learnable scale/shift normalises the log-concentrations;
            # exp() then maps them to positive Dirichlet concentrations.
            log_concentration = nn.BatchNorm(
                use_bias=False,
                use_scale=False,
                momentum=0.9,
                use_running_average=not is_training,
            )(h)
            return jnp.exp(log_concentration)


    class FlaxDecoder(nn.Module):
        """Generative network mapping topic proportions to per-word logits.

        The single bias-free Dense layer's kernel acts as the topic-word matrix.
        ``main`` reads it back from the fitted parameters as
        ``params["decoder$params"]["Dense_0"]["kernel"]`` (the auto-generated name
        of the first ``nn.Dense``), so adding an earlier Dense layer here would
        break that lookup.
        """

        vocab_size: int  # number of output logits, one per vocabulary word
        dropout_rate: float  # dropout probability applied to the input proportions

        @nn.compact
        def __call__(self, inputs, is_training):
            """Return unnormalised logits of shape ``(..., vocab_size)``.

            :param inputs: topic proportions (``model`` passes Dirichlet samples
                ``theta``).
            :param bool is_training: when True, dropout is active and BatchNorm
                normalises with batch statistics; when False, dropout is disabled
                and BatchNorm uses its running averages.
            """
            h = nn.Dropout(self.dropout_rate, deterministic=not is_training)(inputs)
            h = nn.Dense(self.vocab_size, use_bias=False)(h)
            # Normalise logits without learnable scale/shift; the result is used
            # directly as Multinomial logits by ``model``.
            return nn.BatchNorm(
                use_bias=False,
                use_scale=False,
                momentum=0.9,
                use_running_average=not is_training,
            )(h)


    def model(docs, hyperparams, is_training=False):
        """Generative model of ProdLDA.

        For each document in a subsampled mini-batch, topic proportions ``theta``
        are drawn from a flat Dirichlet prior; the Flax decoder turns ``theta``
        into per-word logits, and the observed word counts are scored under a
        Multinomial whose total count is the document's length.

        :param docs: document-term count matrix, one row per document.
        :param dict hyperparams: must provide ``vocab_size``, ``num_topics``,
            ``dropout_rate`` and ``batch_size``.
        :param bool is_training: enables dropout and BatchNorm batch statistics
            in the decoder.
        """
        n_topics = hyperparams["num_topics"]
        decoder_net = FlaxDecoder(
            hyperparams["vocab_size"], hyperparams["dropout_rate"]
        )
        decoder = flax_module(
            "decoder",
            decoder_net,
            input_shape=(1, n_topics),
            # dropout layers need a PRNG key when the module is applied
            apply_rng=["dropout"],
            # BatchNorm keeps running statistics, which is mutable state
            mutable=["batch_stats"],
            # BatchNorm has to be initialised in training mode
            is_training=True,
        )

        n_docs = docs.shape[0]
        batch_size = hyperparams["batch_size"]
        with numpyro.plate("documents", n_docs, subsample_size=batch_size):
            batch_docs = numpyro.subsample(docs, event_dim=1)
            prior = dist.Dirichlet(jnp.ones(n_topics))
            theta = numpyro.sample("theta", prior)

            dropout_key = numpyro.prng_key()
            logits = decoder(theta, is_training, rngs={"dropout": dropout_key})

            doc_lengths = batch_docs.sum(-1)
            likelihood = dist.Multinomial(doc_lengths, logits=logits)
            numpyro.sample("obs", likelihood, obs=batch_docs)


    def guide(docs, hyperparams, is_training=False):
        """Amortised variational guide of ProdLDA.

        The Flax encoder maps each subsampled document's word counts to the
        concentration of a Dirichlet posterior over its topic proportions
        ``theta``.

        :param docs: document-term count matrix, one row per document.
        :param dict hyperparams: must provide ``vocab_size``, ``num_topics``,
            ``hidden``, ``dropout_rate`` and ``batch_size``.
        :param bool is_training: enables dropout and BatchNorm batch statistics
            in the encoder.
        """
        encoder_net = FlaxEncoder(
            hyperparams["vocab_size"],
            hyperparams["num_topics"],
            hyperparams["hidden"],
            hyperparams["dropout_rate"],
        )
        encoder = flax_module(
            "encoder",
            encoder_net,
            input_shape=(1, hyperparams["vocab_size"]),
            # dropout layers need a PRNG key when the module is applied
            apply_rng=["dropout"],
            # BatchNorm keeps running statistics, which is mutable state
            mutable=["batch_stats"],
            # BatchNorm has to be initialised in training mode
            is_training=True,
        )

        n_docs = docs.shape[0]
        batch_size = hyperparams["batch_size"]
        with numpyro.plate("documents", n_docs, subsample_size=batch_size):
            batch_docs = numpyro.subsample(docs, event_dim=1)

            dropout_key = numpyro.prng_key()
            concentration = encoder(
                batch_docs, is_training, rngs={"dropout": dropout_key}
            )

            posterior = dist.Dirichlet(concentration)
            numpyro.sample("theta", posterior)


    def load_data():
        """Download 20 Newsgroups and convert it to a bag-of-words count matrix.

        :return: ``(docs, vocab)``. ``docs`` is a dense JAX array of word counts
            with one row per post and one column per vocabulary term. ``vocab``
            is a DataFrame with a ``word`` column and an ``index`` column giving
            each word's column position in ``docs``.
        """
        corpus = fetch_20newsgroups(subset="all")["data"]
        # drop English stop words, words found in more than half of the posts,
        # and words found in fewer than 20 posts
        vectorizer = CountVectorizer(max_df=0.5, min_df=20, stop_words="english")
        counts = vectorizer.fit_transform(corpus)

        words = vectorizer.get_feature_names_out()
        vocab = pd.DataFrame({"word": words, "index": range(len(words))})

        return jnp.array(counts.toarray()), vocab


    def run_inference(docs, args):
        """Fit the ProdLDA model/guide pair with stochastic variational inference.

        :param docs: document-term count matrix; it is placed on a JAX device
            with ``device_put`` before training.
        :param args: parsed command-line namespace providing ``num_topics``,
            ``hidden``, ``dropout_rate``, ``batch_size``, ``learning_rate``,
            ``num_steps`` and ``disable_progbar``.
        :return: the result of ``SVI.run`` (``main`` reads its ``params``).
        """
        docs = device_put(docs)
        hyperparams = {
            "vocab_size": docs.shape[1],
            "num_topics": args.num_topics,
            "hidden": args.hidden,
            "dropout_rate": args.dropout_rate,
            "batch_size": args.batch_size,
        }

        svi = SVI(
            model,
            guide,
            numpyro.optim.Adam(args.learning_rate),
            loss=TraceMeanField_ELBO(),
        )
        show_progress = not args.disable_progbar
        return svi.run(
            random.key(0),
            args.num_steps,
            docs,
            hyperparams,
            is_training=True,
            progress_bar=show_progress,
        )


    def plot_word_cloud(b, ax, vocab, n, num_words=20):
        """Draw a word cloud of one topic's highest-weighted words.

        :param b: 1-D weights over the vocabulary for a single topic (in ``main``
            this is one row of the softmax-normalised decoder kernel).
        :param ax: matplotlib axes to draw on.
        :param vocab: DataFrame with ``word`` and ``index`` columns mapping
            vocabulary positions to words.
        :param int n: zero-based topic number; the title shows ``n + 1``.
        :param int num_words: how many top-weighted words to include (default 20).
        """
        top = jnp.argsort(b)[::-1][:num_words]
        # look words up via the ``index`` column, preserving the descending order
        # of ``top`` (a left merge keeps the left frame's row order)
        top_frame = pd.DataFrame(top, columns=["index"])
        words = pd.merge(top_frame, vocab[["index", "word"]], how="left", on="index")[
            "word"
        ].tolist()
        freqs = dict(zip(words, b[top].tolist()))

        wc = WordCloud(background_color="white", width=800, height=500)
        wc = wc.generate_from_frequencies(freqs)
        ax.set_title(f"Topic {n + 1}")
        ax.imshow(wc, interpolation="bilinear")
        ax.axis("off")


    def main(args):
        """Load the corpus, fit ProdLDA, and save one word cloud per topic.

        The resulting figure is written to ``wordclouds.png`` in the current
        working directory.
        """
        docs, vocab = load_data()
        print(f"Dictionary size: {len(vocab)}")
        print(f"Corpus size: {docs.shape}")

        svi_result = run_inference(docs, args)

        # The decoder's only Dense layer maps num_topics -> vocab_size, so its
        # kernel has one row per topic; a softmax over each row gives per-topic
        # word weights.
        beta = svi_result.params["decoder$params"]["Dense_0"]["kernel"]
        beta = jax.nn.softmax(beta)
        num_plotted = beta.shape[0]

        # three word clouds per row; ceiling division adds a row for any remainder
        ncols = 3
        nrows = (args.num_topics + ncols - 1) // ncols
        fig, axs = plt.subplots(nrows, ncols, figsize=(14, 3 + 3 * nrows))
        axs = axs.flatten()

        for n in range(num_plotted):
            plot_word_cloud(beta[n], axs[n], vocab, n)

        # hide only the grid cells left over after the last topic
        for ax in axs[num_plotted:]:
            ax.axis("off")

        fig.savefig("wordclouds.png")


    if __name__ == "__main__":
        assert numpyro.__version__.startswith("0.21.0")
        parser = argparse.ArgumentParser(
            description="Probabilistic topic modelling with Flax"
        )
        # Each option takes exactly one value, so a bare flag such as `-n` is
        # rejected by argparse instead of silently becoming None.
        parser.add_argument("-n", "--num-steps", default=30_000, type=int)
        parser.add_argument("-t", "--num-topics", default=12, type=int)
        parser.add_argument("--batch-size", default=32, type=int)
        parser.add_argument("--learning-rate", default=1e-3, type=float)
        parser.add_argument("--hidden", default=200, type=int)
        parser.add_argument("--dropout-rate", default=0.2, type=float)
        parser.add_argument(
            "-dp",
            "--disable-progbar",
            action="store_true",
            default=False,
            help="Whether to disable progress bar",
        )
        parser.add_argument(
            "--device", default="cpu", type=str, help='use "cpu", "gpu" or "tpu".'
        )
        args = parser.parse_args()

        numpyro.set_platform(args.device)
        main(args)


.. _sphx_glr_download_examples_prodlda.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: prodlda.ipynb <prodlda.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: prodlda.py <prodlda.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: prodlda.zip <prodlda.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
