
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/neighbors/plot_nca_illustration.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_neighbors_plot_nca_illustration.py>`
        to download the full example code or to run this example in your browser via JupyterLite or Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_neighbors_plot_nca_illustration.py:


=============================================
Neighborhood Components Analysis Illustration
=============================================

This example illustrates a distance metric learned by Neighborhood Components
Analysis (NCA) so as to maximize a stochastic variant of the leave-one-out
nearest neighbors classification accuracy. It visualizes this metric by
comparing the original point space with the learned embedding. Please refer to
the :ref:`User Guide <nca>` for more information.
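
For reference, the quantity NCA maximizes can be summarized as follows. The
notation is assumed here rather than taken from the code: :math:`L` is the
learned linear transformation and :math:`C_i` is the set of points sharing the
class of :math:`x_i`. Each point :math:`i` selects a neighbor :math:`j` with
probability

.. math::

    p_{ij} = \frac{\exp(-\lVert L x_i - L x_j \rVert^2)}
    {\sum_{k \neq i} \exp(-\lVert L x_i - L x_k \rVert^2)}, \qquad p_{ii} = 0,

and NCA searches for the :math:`L` that maximizes the expected number of
correctly classified points:

.. math::

    \max_{L} \sum_{i} p_i, \qquad p_i = \sum_{j \in C_i} p_{ij}.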

.. GENERATED FROM PYTHON SOURCE LINES 12-24

.. code-block:: Python


    # Authors: The scikit-learn developers
    # SPDX-License-Identifier: BSD-3-Clause

    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib import cm
    from scipy.special import logsumexp

    from sklearn.datasets import make_classification
    from sklearn.neighbors import NeighborhoodComponentsAnalysis








.. GENERATED FROM PYTHON SOURCE LINES 25-31

Original points
---------------
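First we create a data set of 9 samples from 3 classes and plot the points
in the original space. For this example, we focus on the classification of
point no. 3. The thickness of a link between point no. 3 and another point
is proportional to the probability that point no. 3 selects that point as its
neighbor: a softmax over the negative squared Euclidean distances, so closer
points get thicker links.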
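
The link thickness in this example is computed by ``link_thickness_i`` in the
next code block. As a minimal sketch of how such weights behave, the snippet
below applies the same softmax over negative squared distances to three
collinear toy points. The helper name ``neighbor_probabilities`` and the toy
array are hypothetical and used only for illustration.

.. code-block:: Python

    import numpy as np
    from scipy.special import logsumexp


    def neighbor_probabilities(X, i):
        """Softmax of negative squared distances from point i (hypothetical helper)."""
        diff = X[i] - X
        sq_dist = np.einsum("ij,ij->i", diff, diff)
        # Point i never selects itself: exp(-inf) is 0.
        sq_dist[i] = np.inf
        # Normalize with the log-sum-exp trick for numerical stability.
        return np.exp(-sq_dist - logsumexp(-sq_dist))


    # Three points on a line at x = 0, 1 and 2.
    X_toy = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
    p = neighbor_probabilities(X_toy, 0)
    print(np.round(p, 3))  # approximately [0, 0.953, 0.047]

The weights sum to one, and the point at distance 1 receives about
``exp(3) ≈ 20`` times the weight of the point at distance 2, because the
ratio of the two weights is ``exp(-1) / exp(-4)``.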

.. GENERATED FROM PYTHON SOURCE LINES 31-79

.. code-block:: Python


    X, y = make_classification(
        n_samples=9,
        n_features=2,
        n_informative=2,
        n_redundant=0,
        n_classes=3,
        n_clusters_per_class=1,
        class_sep=1.0,
        random_state=0,
    )

    plt.figure(1)
    ax = plt.gca()
    for i in range(X.shape[0]):
        ax.text(X[i, 0], X[i, 1], str(i), va="center", ha="center")
        ax.scatter(X[i, 0], X[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)

    ax.set_title("Original points")
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)
    ax.axis("equal")  # so that boundaries are displayed correctly as circles


    def link_thickness_i(X, i):
        # Squared Euclidean distances from point i to every point.
        diff_embedded = X[i] - X
        dist_embedded = np.einsum("ij,ij->i", diff_embedded, diff_embedded)
        # Exclude point i itself: exp(-inf) contributes zero weight.
        dist_embedded[i] = np.inf

        # Compute the softmax of the negative squared distances (use the
        # log-sum-exp trick to avoid numerical instabilities).
        exp_dist_embedded = np.exp(-dist_embedded - logsumexp(-dist_embedded))
        return exp_dist_embedded


    def relate_point(X, i, ax):
        pt_i = X[i]
        # The link weights only depend on i, so compute them once.
        thickness = link_thickness_i(X, i)
        for j, pt_j in enumerate(X):
            if i != j:
                line = ([pt_i[0], pt_j[0]], [pt_i[1], pt_j[1]])
                # Color each link by the class of the other endpoint.
                ax.plot(*line, c=cm.Set1(y[j]), linewidth=5 * thickness[j])


    i = 3
    relate_point(X, i, ax)
    plt.show()




.. image-sg:: /auto_examples/neighbors/images/sphx_glr_plot_nca_illustration_001.png
   :alt: Original points
   :srcset: /auto_examples/neighbors/images/sphx_glr_plot_nca_illustration_001.png
   :class: sphx-glr-single-img
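
Because the softmax weight decreases monotonically with distance, the thickest
link from point no. 3 should always go to its nearest neighbor. The sketch
below checks this with :class:`~sklearn.neighbors.NearestNeighbors`. It
re-creates the data set so that it runs on its own, and it inlines the
link-weight computation (using ``scipy.special.softmax``) instead of reusing
``link_thickness_i``.

.. code-block:: Python

    import numpy as np
    from scipy.special import softmax
    from sklearn.datasets import make_classification
    from sklearn.neighbors import NearestNeighbors

    X, y = make_classification(
        n_samples=9,
        n_features=2,
        n_informative=2,
        n_redundant=0,
        n_classes=3,
        n_clusters_per_class=1,
        class_sep=1.0,
        random_state=0,
    )
    i = 3

    # Link weights of point i: softmax over negative squared distances,
    # with point i itself excluded.
    sq_dist = np.sum((X - X[i]) ** 2, axis=1)
    sq_dist[i] = np.inf
    weights = softmax(-sq_dist)

    # Query two neighbors: the first one returned is point i itself.
    nn = NearestNeighbors(n_neighbors=2).fit(X)
    _, indices = nn.kneighbors(X[[i]])
    nearest = indices[0, 1]

    print("thickest link:", np.argmax(weights), "nearest neighbor:", nearest)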





.. GENERATED FROM PYTHON SOURCE LINES 80-85

Learning an embedding
---------------------
We use :class:`~sklearn.neighbors.NeighborhoodComponentsAnalysis` to learn a
linear embedding and plot the points after the transformation. We then draw
the same links from point no. 3 in the embedded space, so that the link
thicknesses show which neighbors it is now most likely to select.
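
The embedding learned by NCA is linear. As a small check of that, the sketch
below reads the learned matrix from the fitted estimator's ``components_``
attribute (shape ``(n_components, n_features)``) and compares ``transform``
with an explicit matrix product. It re-creates the data set so that it runs on
its own.

.. code-block:: Python

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.neighbors import NeighborhoodComponentsAnalysis

    X, y = make_classification(
        n_samples=9,
        n_features=2,
        n_informative=2,
        n_redundant=0,
        n_classes=3,
        n_clusters_per_class=1,
        class_sep=1.0,
        random_state=0,
    )

    nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=0).fit(X, y)
    L = nca.components_  # learned linear map, shape (n_components, n_features)

    # transform(X) corresponds to the matrix product X @ L.T.
    X_embedded = nca.transform(X)
    print(L.shape, np.allclose(X_embedded, X @ L.T))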

.. GENERATED FROM PYTHON SOURCE LINES 85-103

.. code-block:: Python


    nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=0)
    nca = nca.fit(X, y)

    plt.figure(2)
    ax2 = plt.gca()
    X_embedded = nca.transform(X)
    relate_point(X_embedded, i, ax2)

    for i in range(len(X)):
        ax2.text(X_embedded[i, 0], X_embedded[i, 1], str(i), va="center", ha="center")
        ax2.scatter(
            X_embedded[i, 0], X_embedded[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4
        )

    ax2.set_title("NCA embedding")
    ax2.axes.get_xaxis().set_visible(False)
    ax2.axes.get_yaxis().set_visible(False)
    ax2.axis("equal")
    plt.show()



.. image-sg:: /auto_examples/neighbors/images/sphx_glr_plot_nca_illustration_002.png
   :alt: NCA embedding
   :srcset: /auto_examples/neighbors/images/sphx_glr_plot_nca_illustration_002.png
   :class: sphx-glr-single-img
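
To quantify what this figure shows, the sketch below evaluates the sum of the
probabilities :math:`p_i` that each point selects a same-class neighbor, i.e.
the expected number of correctly classified points, before and after fitting.
The helper ``nca_objective`` is hypothetical. Unlike the example above, the
sketch fits with ``init="identity"`` so that the optimization starts exactly
from the original space; assuming the underlying L-BFGS-B solver never accepts
a step that worsens its loss, the objective should not decrease.

.. code-block:: Python

    import numpy as np
    from scipy.special import softmax
    from sklearn.datasets import make_classification
    from sklearn.neighbors import NeighborhoodComponentsAnalysis


    def nca_objective(X_emb, y):
        """Sum over i of p_i, the probability that i picks a same-class neighbor."""
        diff = X_emb[:, None, :] - X_emb[None, :, :]
        sq_dist = np.einsum("ijk,ijk->ij", diff, diff)
        np.fill_diagonal(sq_dist, np.inf)  # a point never selects itself
        p_ij = softmax(-sq_dist, axis=1)  # row i: neighbor probabilities of i
        same_class = y[:, None] == y[None, :]
        return np.sum(p_ij * same_class)


    X, y = make_classification(
        n_samples=9,
        n_features=2,
        n_informative=2,
        n_redundant=0,
        n_classes=3,
        n_clusters_per_class=1,
        class_sep=1.0,
        random_state=0,
    )

    # init="identity" starts the optimization from the original space.
    nca = NeighborhoodComponentsAnalysis(init="identity", max_iter=30, random_state=0)
    nca.fit(X, y)

    before = nca_objective(X, y)
    after = nca_objective(nca.transform(X), y)
    print(f"objective before: {before:.3f}, after: {after:.3f} (max {len(y)})")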






.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.124 seconds)


.. _sphx_glr_download_auto_examples_neighbors_plot_nca_illustration.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/1.8.X?urlpath=lab/tree/notebooks/auto_examples/neighbors/plot_nca_illustration.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: lite-badge

      .. image:: images/jupyterlite_badge_logo.svg
        :target: ../../lite/lab/index.html?path=auto_examples/neighbors/plot_nca_illustration.ipynb
        :alt: Launch JupyterLite
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_nca_illustration.ipynb <plot_nca_illustration.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_nca_illustration.py <plot_nca_illustration.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_nca_illustration.zip <plot_nca_illustration.zip>`


.. include:: plot_nca_illustration.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
