
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/semi_supervised/plot_label_propagation_digits.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_semi_supervised_plot_label_propagation_digits.py>`
        to download the full example code or to run this example in your browser via JupyterLite or Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_semi_supervised_plot_label_propagation_digits.py:


===================================================
Label Propagation digits: Demonstrating performance
===================================================

This example demonstrates the power of semi-supervised learning by
training a Label Spreading model to classify handwritten digits
from a very small set of labels.

The handwritten digit dataset has 1797 total points. The model is
trained on a random subset of 340 points, of which only 40 are labeled.
Performance on the remaining 300 points is reported as a classification
report with per-class metrics and as a confusion matrix; in the run
shown here the overall accuracy is about 0.90.

At the end, the 10 most uncertain predictions are shown.

.. GENERATED FROM PYTHON SOURCE LINES 18-22

.. code-block:: Python


    # Authors: The scikit-learn developers
    # SPDX-License-Identifier: BSD-3-Clause








.. GENERATED FROM PYTHON SOURCE LINES 23-27

Data generation
---------------

We use the digits dataset and shuffle its sample indices with a fixed random
seed, so that a random subset of samples can be selected.

.. GENERATED FROM PYTHON SOURCE LINES 27-36

.. code-block:: Python

    import numpy as np

    from sklearn import datasets

    digits = datasets.load_digits()
    rng = np.random.RandomState(2)
    indices = np.arange(len(digits.data))
    rng.shuffle(indices)








.. GENERATED FROM PYTHON SOURCE LINES 37-40

We select 340 samples, of which only 40 keep a known label. We therefore store
the indices of the other 300 samples, whose labels are treated as unknown.

.. GENERATED FROM PYTHON SOURCE LINES 41-52

.. code-block:: Python

    X = digits.data[indices[:340]]
    y = digits.target[indices[:340]]
    images = digits.images[indices[:340]]

    n_total_samples = len(y)
    n_labeled_points = 40

    indices = np.arange(n_total_samples)

    unlabeled_set = indices[n_labeled_points:]








.. GENERATED FROM PYTHON SOURCE LINES 53-54

Copy the targets and mask the labels of the unlabeled samples with -1, the
value that marks a sample as unlabeled.

.. GENERATED FROM PYTHON SOURCE LINES 54-57

.. code-block:: Python

    y_train = np.copy(y)
    y_train[unlabeled_set] = -1








.. GENERATED FROM PYTHON SOURCE LINES 58-63

Semi-supervised learning
------------------------

We fit a :class:`~sklearn.semi_supervised.LabelSpreading` and use it to predict
the unknown labels.
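
As a minimal, self-contained sketch of the same API on toy data (the arrays
``X_toy`` and ``y_toy`` and the ``gamma`` value are illustrative assumptions,
not part of this example), two well-separated clusters with one labeled point
each are enough for the labels to spread to their neighbors:

.. code-block:: Python

    import numpy as np

    from sklearn.semi_supervised import LabelSpreading

    # Hypothetical toy data: two well-separated 1D clusters.
    X_toy = np.array([[0.0], [0.1], [0.2], [5.0], [5.1], [5.2]])
    # Only one point per cluster is labeled; -1 marks an unlabeled sample.
    y_toy = np.array([0, -1, -1, 1, -1, -1])

    toy_model = LabelSpreading(kernel="rbf", gamma=1.0)
    toy_model.fit(X_toy, y_toy)

    # transduction_ holds the label inferred for every training sample.
    print(toy_model.transduction_)

``transduction_`` contains one inferred label per training sample, labeled or
not, which is why the digits code indexes it with ``unlabeled_set``.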

.. GENERATED FROM PYTHON SOURCE LINES 63-76

.. code-block:: Python

    from sklearn.metrics import classification_report
    from sklearn.semi_supervised import LabelSpreading

    lp_model = LabelSpreading(gamma=0.25, max_iter=20)
    lp_model.fit(X, y_train)
    predicted_labels = lp_model.transduction_[unlabeled_set]
    true_labels = y[unlabeled_set]

    print(
        "Label Spreading model: %d labeled & %d unlabeled points (%d total)"
        % (n_labeled_points, n_total_samples - n_labeled_points, n_total_samples)
    )





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Label Spreading model: 40 labeled & 300 unlabeled points (340 total)




.. GENERATED FROM PYTHON SOURCE LINES 77-78

Classification report

.. GENERATED FROM PYTHON SOURCE LINES 78-80

.. code-block:: Python

    print(classification_report(true_labels, predicted_labels))





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

                  precision    recall  f1-score   support

               0       1.00      1.00      1.00        27
               1       0.82      1.00      0.90        37
               2       1.00      0.86      0.92        28
               3       1.00      0.80      0.89        35
               4       0.92      1.00      0.96        24
               5       0.74      0.94      0.83        34
               6       0.89      0.96      0.92        25
               7       0.94      0.89      0.91        35
               8       1.00      0.68      0.81        31
               9       0.81      0.88      0.84        24

        accuracy                           0.90       300
       macro avg       0.91      0.90      0.90       300
    weighted avg       0.91      0.90      0.90       300





.. GENERATED FROM PYTHON SOURCE LINES 81-82

Confusion matrix
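
For reference, the counts behind such a plot can be sketched on toy labels with
:func:`~sklearn.metrics.confusion_matrix`. The label lists here are made up for
illustration. Rows correspond to true classes and columns to predicted classes:

.. code-block:: Python

    from sklearn.metrics import confusion_matrix

    # Hypothetical true and predicted labels for five samples.
    toy_true = [0, 0, 1, 1, 2]
    toy_pred = [0, 1, 1, 1, 2]

    # Entry (i, j) counts samples of true class i predicted as class j.
    cm = confusion_matrix(toy_true, toy_pred, labels=[0, 1, 2])
    print(cm)

Here one sample of class 0 is predicted as class 1, so it appears off the
diagonal in the first row. All other samples fall on the diagonal.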

.. GENERATED FROM PYTHON SOURCE LINES 82-88

.. code-block:: Python

    from sklearn.metrics import ConfusionMatrixDisplay

    ConfusionMatrixDisplay.from_predictions(
        true_labels, predicted_labels, labels=lp_model.classes_
    )




.. image-sg:: /auto_examples/semi_supervised/images/sphx_glr_plot_label_propagation_digits_001.png
   :alt: plot label propagation digits
   :srcset: /auto_examples/semi_supervised/images/sphx_glr_plot_label_propagation_digits_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    <sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay object at 0x7fe89ca28890>



.. GENERATED FROM PYTHON SOURCE LINES 89-93

Plot the most uncertain predictions
-----------------------------------

Here, we will pick and show the 10 most uncertain predictions.
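
Uncertainty is measured here by the entropy of each sample's predicted class
distribution. A small sketch on made-up distributions (the array
``toy_distributions`` is an illustrative assumption) shows why high entropy
flags uncertain predictions and how ``argsort`` selects them:

.. code-block:: Python

    import numpy as np
    from scipy import stats

    # Each row is a hypothetical per-sample class distribution.
    toy_distributions = np.array(
        [
            [0.98, 0.01, 0.01],  # confident prediction, low entropy
            [0.34, 0.33, 0.33],  # nearly uniform, high entropy
            [0.60, 0.30, 0.10],  # in between
        ]
    )

    # scipy.stats.entropy reduces along axis 0 by default, hence the
    # transpose to get one entropy value per sample.
    toy_entropies = stats.entropy(toy_distributions.T)

    # argsort is ascending, so the last entries are the most uncertain.
    most_uncertain = np.argsort(toy_entropies)[-2:]
    print(most_uncertain)

The nearly uniform row comes last in the ascending order, so it is ranked as
the most uncertain sample.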

.. GENERATED FROM PYTHON SOURCE LINES 93-97

.. code-block:: Python

    from scipy import stats

    pred_entropies = stats.entropy(lp_model.label_distributions_.T)








.. GENERATED FROM PYTHON SOURCE LINES 98-99

Pick the 10 predictions with the highest entropy, i.e. the most uncertain ones

.. GENERATED FROM PYTHON SOURCE LINES 99-101

.. code-block:: Python

    uncertainty_index = np.argsort(pred_entropies)[-10:]








.. GENERATED FROM PYTHON SOURCE LINES 102-103

Plot each selected image with its predicted and true label

.. GENERATED FROM PYTHON SOURCE LINES 103-119

.. code-block:: Python

    import matplotlib.pyplot as plt

    f = plt.figure(figsize=(7, 5))
    for index, image_index in enumerate(uncertainty_index):
        image = images[image_index]

        sub = f.add_subplot(2, 5, index + 1)
        sub.imshow(image, cmap=plt.cm.gray_r)
        plt.xticks([])
        plt.yticks([])
        sub.set_title(
            "predict: %i\ntrue: %i" % (lp_model.transduction_[image_index], y[image_index])
        )

    f.suptitle("Learning with small amount of labeled data")
    plt.show()



.. image-sg:: /auto_examples/semi_supervised/images/sphx_glr_plot_label_propagation_digits_002.png
   :alt: Learning with small amount of labeled data, predict: 1 true: 2, predict: 2 true: 2, predict: 8 true: 8, predict: 1 true: 8, predict: 1 true: 8, predict: 1 true: 8, predict: 3 true: 3, predict: 8 true: 8, predict: 2 true: 2, predict: 7 true: 2
   :srcset: /auto_examples/semi_supervised/images/sphx_glr_plot_label_propagation_digits_002.png
   :class: sphx-glr-single-img






.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.258 seconds)


.. _sphx_glr_download_auto_examples_semi_supervised_plot_label_propagation_digits.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/1.8.X?urlpath=lab/tree/notebooks/auto_examples/semi_supervised/plot_label_propagation_digits.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: lite-badge

      .. image:: images/jupyterlite_badge_logo.svg
        :target: ../../lite/lab/index.html?path=auto_examples/semi_supervised/plot_label_propagation_digits.ipynb
        :alt: Launch JupyterLite
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_label_propagation_digits.ipynb <plot_label_propagation_digits.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_label_propagation_digits.py <plot_label_propagation_digits.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_label_propagation_digits.zip <plot_label_propagation_digits.zip>`


.. include:: plot_label_propagation_digits.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
