
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/multioutput/plot_classifier_chain_yeast.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_multioutput_plot_classifier_chain_yeast.py>`
        to download the full example code or to run this example in your browser via JupyterLite or Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_multioutput_plot_classifier_chain_yeast.py:


==================================================
Multilabel classification using a classifier chain
==================================================
This example shows how to use :class:`~sklearn.multioutput.ClassifierChain` to solve
a multilabel classification problem.

The most naive strategy to solve such a task is to independently train a binary
classifier on each label (i.e. each column of the target variable). At prediction
time, the outputs of these binary classifiers are combined to form the multilabel
prediction.

This strategy cannot model relationships between labels. The
:class:`~sklearn.multioutput.ClassifierChain` is a meta-estimator (i.e. an estimator
taking an inner estimator) that implements a more advanced strategy. The binary
classifiers are arranged in a chain, where the prediction of each classifier in the
chain is used as a feature for training the next classifier on a new label. These
additional features allow each chain to exploit correlations among labels.
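
The chaining idea can be sketched by hand on a small synthetic problem. This is
only an illustration under stated assumptions, not the internal implementation of
:class:`~sklearn.multioutput.ClassifierChain`: the toy data, the fixed label order
and the variable names are made up for this sketch. Here each classifier is trained
on the original features plus the true values of the earlier labels, and at
prediction time the earlier predictions take their place.

.. code-block:: Python

    import numpy as np

    from sklearn.datasets import make_multilabel_classification
    from sklearn.linear_model import LogisticRegression

    # Synthetic multilabel data (an assumption for this sketch, not the yeast data).
    X_toy, Y_toy = make_multilabel_classification(
        n_samples=300, n_features=10, n_classes=3, random_state=0
    )

    # Training: the classifier for label j sees X plus the true labels 0..j-1.
    toy_models = []
    for j in range(Y_toy.shape[1]):
        X_aug = np.hstack([X_toy, Y_toy[:, :j]])
        toy_models.append(LogisticRegression(max_iter=1000).fit(X_aug, Y_toy[:, j]))

    # Prediction: earlier predictions replace the true labels as extra features.
    Y_toy_pred = np.zeros_like(Y_toy)
    for j, model in enumerate(toy_models):
        X_aug = np.hstack([X_toy, Y_toy_pred[:, :j]])
        Y_toy_pred[:, j] = model.predict(X_aug)

    print(Y_toy_pred.shape)

The last classifier in this sketch is fitted on 12 columns: the 10 original features
plus the 2 labels that precede it in the chain.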

The :ref:`Jaccard similarity <jaccard_similarity_score>` score of a chain tends to be
greater than that of the set of independent base models.

.. GENERATED FROM PYTHON SOURCE LINES 22-26

.. code-block:: Python


    # Authors: The scikit-learn developers
    # SPDX-License-Identifier: BSD-3-Clause








.. GENERATED FROM PYTHON SOURCE LINES 27-36

Loading a dataset
-----------------
For this example, we use the `yeast
<https://www.openml.org/d/40597>`_ dataset which contains
2,417 data points, each with 103 features and 14 possible labels. Each
data point has at least one label. As a baseline, we first train a logistic
regression classifier for each of the 14 labels. To evaluate the performance of
these classifiers, we predict on a held-out test set and calculate the
Jaccard similarity for each sample.
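
As a small worked example of the metric, the snippet below computes the
sample-averaged Jaccard similarity by hand and with
:func:`~sklearn.metrics.jaccard_score`. The two tiny label matrices are made up for
illustration. For each sample, the score is the number of labels present in both the
truth and the prediction divided by the number of labels present in either.

.. code-block:: Python

    import numpy as np

    from sklearn.metrics import jaccard_score

    # Hypothetical ground truth and predictions: 2 samples, 3 labels.
    Y_true_toy = np.array([[1, 1, 0], [0, 1, 1]])
    Y_pred_toy = np.array([[1, 0, 0], [0, 1, 1]])

    # Per-sample intersection and union of the predicted and true label sets.
    intersection = np.logical_and(Y_true_toy, Y_pred_toy).sum(axis=1)
    union = np.logical_or(Y_true_toy, Y_pred_toy).sum(axis=1)
    manual_score = (intersection / union).mean()

    sklearn_score = jaccard_score(Y_true_toy, Y_pred_toy, average="samples")
    print(manual_score, sklearn_score)

The first sample scores 1/2 and the second 2/2, so both computations give 0.75.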

.. GENERATED FROM PYTHON SOURCE LINES 36-48

.. code-block:: Python


    import matplotlib.pyplot as plt
    import numpy as np

    from sklearn.datasets import fetch_openml
    from sklearn.model_selection import train_test_split

    # Load a multi-label dataset from https://www.openml.org/d/40597
    X, Y = fetch_openml("yeast", version=4, return_X_y=True)
    Y = Y == "TRUE"
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=0)








.. GENERATED FROM PYTHON SOURCE LINES 49-61

Fit models
----------
We fit :class:`~sklearn.linear_model.LogisticRegression` wrapped by
:class:`~sklearn.multiclass.OneVsRestClassifier` and an ensemble of multiple
:class:`~sklearn.multioutput.ClassifierChain` instances.

LogisticRegression wrapped by OneVsRestClassifier
**************************************************
Since by default :class:`~sklearn.linear_model.LogisticRegression` can't
handle data with multiple targets, we need to use
:class:`~sklearn.multiclass.OneVsRestClassifier`.
After fitting the model, we calculate the Jaccard similarity score on the test set.
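
To see what the wrapper does with a multilabel target, the following sketch fits it
on synthetic data (an assumption for this illustration, not the yeast dataset) and
inspects the fitted object. With a label indicator matrix as target, one binary
estimator is fitted per label column.

.. code-block:: Python

    from sklearn.datasets import make_multilabel_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.multiclass import OneVsRestClassifier

    # Synthetic multilabel data with 4 labels (illustrative only).
    X_toy, Y_toy = make_multilabel_classification(
        n_samples=300, n_features=8, n_classes=4, random_state=0
    )

    toy_ovr = OneVsRestClassifier(LogisticRegression(max_iter=1000))
    toy_ovr.fit(X_toy, Y_toy)

    # One independent binary classifier per label column.
    print(len(toy_ovr.estimators_))
    # The prediction is again a label indicator matrix.
    print(toy_ovr.predict(X_toy).shape)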

.. GENERATED FROM PYTHON SOURCE LINES 61-72

.. code-block:: Python


    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import jaccard_score
    from sklearn.multiclass import OneVsRestClassifier

    base_lr = LogisticRegression()
    ovr = OneVsRestClassifier(base_lr)
    ovr.fit(X_train, Y_train)
    Y_pred_ovr = ovr.predict(X_test)
    ovr_jaccard_score = jaccard_score(Y_test, Y_pred_ovr, average="samples")








.. GENERATED FROM PYTHON SOURCE LINES 73-84

Chain of binary classifiers
***************************
Because the models in each chain are arranged randomly, there is significant
variation in performance among the chains. Presumably there is an optimal
ordering of the labels in a chain that will yield the best performance.
However, we do not know that ordering a priori. Instead, we can build a
voting ensemble of classifier chains by averaging the predicted probabilities
of the chains and applying a threshold of 0.5. The Jaccard similarity score of
the ensemble is greater than that of the independent models and tends to exceed
the score of each chain in the ensemble (although this is not guaranteed
with randomly ordered chains).
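
The voting step can be illustrated on a few numbers. The probabilities below are
made up for this sketch. It shows that averaging predicted probabilities and then
thresholding at 0.5 is not the same as a majority vote over predictions that were
thresholded first.

.. code-block:: Python

    import numpy as np

    # Hypothetical probabilities from 3 chains for 2 samples and 2 labels,
    # with shape (n_chains, n_samples, n_labels).
    proba_per_chain = np.array(
        [
            [[0.9, 0.4], [0.2, 0.6]],
            [[0.7, 0.3], [0.4, 0.7]],
            [[0.8, 0.65], [0.3, 0.1]],
        ]
    )

    # Soft voting: average the probabilities, then apply the threshold.
    soft_vote = proba_per_chain.mean(axis=0) >= 0.5

    # Hard voting: threshold each chain first, then take the majority.
    hard_vote = (proba_per_chain >= 0.5).mean(axis=0) >= 0.5

    print(soft_vote)
    print(hard_vote)

For the second sample and second label, two of the three chains predict the label,
but one chain is confident that it is absent, so the averaged probability stays
below 0.5 and the two voting schemes disagree.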

.. GENERATED FROM PYTHON SOURCE LINES 84-102

.. code-block:: Python


    from sklearn.multioutput import ClassifierChain

    chains = [ClassifierChain(base_lr, order="random", random_state=i) for i in range(10)]
    for chain in chains:
        chain.fit(X_train, Y_train)

    Y_pred_chains = np.array([chain.predict_proba(X_test) for chain in chains])
    chain_jaccard_scores = [
        jaccard_score(Y_test, Y_pred_chain >= 0.5, average="samples")
        for Y_pred_chain in Y_pred_chains
    ]

    Y_pred_ensemble = Y_pred_chains.mean(axis=0)
    ensemble_jaccard_score = jaccard_score(
        Y_test, Y_pred_ensemble >= 0.5, average="samples"
    )








.. GENERATED FROM PYTHON SOURCE LINES 103-108

Plot results
------------
Plot the Jaccard similarity scores for the independent model, each of the
chains, and the ensemble (note that the vertical axis on this plot does
not begin at 0).

.. GENERATED FROM PYTHON SOURCE LINES 108-140

.. code-block:: Python


    model_scores = [ovr_jaccard_score] + chain_jaccard_scores + [ensemble_jaccard_score]

    model_names = (
        "Independent",
        "Chain 1",
        "Chain 2",
        "Chain 3",
        "Chain 4",
        "Chain 5",
        "Chain 6",
        "Chain 7",
        "Chain 8",
        "Chain 9",
        "Chain 10",
        "Ensemble",
    )

    x_pos = np.arange(len(model_names))

    fig, ax = plt.subplots(figsize=(7, 4))
    ax.grid(True)
    ax.set_title("Classifier Chain Ensemble Performance Comparison")
    ax.set_xticks(x_pos)
    ax.set_xticklabels(model_names, rotation="vertical")
    ax.set_ylabel("Jaccard Similarity Score")
    ax.set_ylim([min(model_scores) * 0.9, max(model_scores) * 1.1])
    colors = ["r"] + ["b"] * len(chain_jaccard_scores) + ["g"]
    ax.bar(x_pos, model_scores, alpha=0.5, color=colors)
    plt.tight_layout()
    plt.show()




.. image-sg:: /auto_examples/multioutput/images/sphx_glr_plot_classifier_chain_yeast_001.png
   :alt: Classifier Chain Ensemble Performance Comparison
   :srcset: /auto_examples/multioutput/images/sphx_glr_plot_classifier_chain_yeast_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 141-154

Results interpretation
----------------------
There are three main takeaways from this plot:

- The independent model wrapped by :class:`~sklearn.multiclass.OneVsRestClassifier`
  performs worse than the ensemble of classifier chains and some of the individual
  chains. This is because the logistic regression does not model the relationships
  between the labels.
- :class:`~sklearn.multioutput.ClassifierChain` takes advantage of correlations
  among labels, but due to the random nature of the label ordering, it can yield a
  worse result than an independent model.
- An ensemble of chains performs better because it not only captures relationships
  between labels but also does not make strong assumptions about their correct order.
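
To check which label ordering a given chain ended up with, the fitted ``order_``
attribute can be inspected. The sketch below uses synthetic data and hypothetical
variable names rather than the chains fitted above.

.. code-block:: Python

    from sklearn.datasets import make_multilabel_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.multioutput import ClassifierChain

    # Synthetic multilabel data with 4 labels (illustrative only).
    X_toy, Y_toy = make_multilabel_classification(
        n_samples=500, n_features=8, n_classes=4, random_state=0
    )

    toy_chains = [
        ClassifierChain(
            LogisticRegression(max_iter=1000), order="random", random_state=i
        ).fit(X_toy, Y_toy)
        for i in range(3)
    ]

    # Each chain stores the permutation of label indices it was fitted with.
    for i, chain in enumerate(toy_chains):
        print(f"chain {i}: label order = {list(chain.order_)}")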


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.792 seconds)


.. _sphx_glr_download_auto_examples_multioutput_plot_classifier_chain_yeast.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/1.8.X?urlpath=lab/tree/notebooks/auto_examples/multioutput/plot_classifier_chain_yeast.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: lite-badge

      .. image:: images/jupyterlite_badge_logo.svg
        :target: ../../lite/lab/index.html?path=auto_examples/multioutput/plot_classifier_chain_yeast.ipynb
        :alt: Launch JupyterLite
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_classifier_chain_yeast.ipynb <plot_classifier_chain_yeast.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_classifier_chain_yeast.py <plot_classifier_chain_yeast.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_classifier_chain_yeast.zip <plot_classifier_chain_yeast.zip>`


.. include:: plot_classifier_chain_yeast.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
