
Linear Discriminant Analysis in Machine Learning

Last Updated : 18 May, 2025

When working with high-dimensional datasets, dimensionality reduction techniques make data exploration and modeling more efficient. Linear Discriminant Analysis (LDA), also known as Normal Discriminant Analysis, is a supervised technique that helps separate two or more classes by projecting a higher-dimensional data space onto a lower-dimensional space. It identifies a linear combination of features that best separates the classes within a dataset.

Two overlapping classes

For example, suppose we have two classes that need to be separated efficiently. Each class may have multiple features, and using only a single feature to classify them may result in overlap. To solve this, LDA combines multiple features to improve classification accuracy. LDA rests on a few assumptions, and understanding them gives a clearer picture of how it works.

Key Assumptions of LDA

For LDA to perform effectively, certain assumptions are made:

  • The data for each class is approximately Gaussian (normally) distributed.
  • All classes share the same covariance matrix.
  • The classes are roughly linearly separable, i.e. a linear boundary can separate them.

If these assumptions are met, LDA can produce very good results. For example, when data points belonging to two classes are plotted and they overlap in the original feature space, LDA will attempt to find a projection that maximizes class separability.
 

Linearly Separable Dataset

The image shows an example where the classes (black and green circles) overlap in the original feature space. LDA attempts to separate them using the red dashed line. It uses both axes (X and Y) to generate a new axis in such a way that the distance between the means of the two classes is maximized while the variation within each class is minimized. This transforms the dataset into a space where the classes are better separated. After the data points are projected onto the new axis, class separation is maximized: the new axis allows for clearer classification because it enhances the distance between the means of the two classes.

The perpendicular distance between the new axis and the data points

The perpendicular distance between the decision boundary and the data points helps visualize how LDA works: it reduces within-class variation while increasing the separation between classes. After generating the new axis using the criteria mentioned above, all data points of both classes are plotted on this new axis, as shown in the figure below.

The two classes projected onto the new one-dimensional LDA axis

It shows how LDA creates a new axis onto which the data is projected so that the two classes are separated along a linear path. LDA fails, however, when the means of the two distributions coincide, because no new axis can then make the classes linearly separable. In such cases we use non-linear discriminant analysis.
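The snippet below is a minimal sketch of this failure mode on made-up data: two classes that share the same mean (an inner cluster surrounded by a ring), so no single linear axis separates them and LDA's accuracy stays near chance.

Python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)

# Two classes with the SAME mean (the origin) but different spreads:
# an inner cluster surrounded by a ring of points.
inner = rng.normal(scale=0.5, size=(200, 2))
angles = rng.uniform(0, 2 * np.pi, 200)
ring = 3 * np.c_[np.cos(angles), np.sin(angles)] + rng.normal(scale=0.2, size=(200, 2))

X = np.vstack([inner, ring])
y = np.array([0] * 200 + [1] * 200)

# Accuracy stays close to chance (about 0.5) because the class means coincide
print(cross_val_score(LinearDiscriminantAnalysis(), X, y, cv=5).mean())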

How Does LDA Work?

LDA works by finding directions in the feature space that best separate the classes. It does this by maximizing the difference between the class means while minimizing the spread within each class.

Let's assume we have two classes and d-dimensional samples x_1, x_2, \ldots, x_n, where:

  • n_1 samples belong to class c_1
  • n_2 samples belong to class c_2.

If x_i represents a data point, its projection onto the line defined by the unit vector v is v^T x_i. Let the means of class c_1 and class c_2 before projection be \mu_1 and \mu_2 respectively. After projection the new means are \hat{\mu}_1 = v^T \mu_1 and \hat{\mu}_2 = v^T \mu_2.

Our aim is to maximize the separation between the projected means \hat{\mu}_1 and \hat{\mu}_2 relative to the spread of each class after projection. The scatter of the projected samples of class c_1 is:

s_1^2 = \sum_{x_i \in c_1} (v^T x_i - \hat{\mu}_1)^2

Similarly, for class c_2:

s_2^2 = \sum_{x_i \in c_2} (v^T x_i - \hat{\mu}_2)^2

The goal is to maximize the ratio of the between-class scatter to the within-class scatter, which leads to the Fisher criterion:

J(v) = \frac{(\hat{\mu}_1 - \hat{\mu}_2)^2}{s_1^2 + s_2^2}

For the best separation, we take v to be the eigenvector corresponding to the largest eigenvalue of S_W^{-1} S_B, where S_W is the within-class scatter matrix and S_B is the between-class scatter matrix.
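As a minimal NumPy sketch of this derivation (the two small point sets below are made up purely for illustration), the scatter matrices and the leading eigenvector of S_W^{-1} S_B can be computed directly:

Python
import numpy as np

# Two small 2-D classes; the point values are made up purely for illustration
X1 = np.array([[4.0, 2.0], [2.0, 4.0], [2.0, 3.0], [3.0, 6.0], [4.0, 4.0]])
X2 = np.array([[9.0, 10.0], [6.0, 8.0], [9.0, 5.0], [8.0, 7.0], [10.0, 8.0]])

mu1, mu2 = X1.mean(axis=0), X2.mean(axis=0)

# Within-class scatter S_W: sum of the per-class scatter matrices
S_W = (X1 - mu1).T @ (X1 - mu1) + (X2 - mu2).T @ (X2 - mu2)

# Between-class scatter S_B: outer product of the difference of the means
diff = (mu1 - mu2).reshape(-1, 1)
S_B = diff @ diff.T

# Direction maximizing J(v): eigenvector of S_W^-1 S_B with the largest eigenvalue
eigvals, eigvecs = np.linalg.eig(np.linalg.inv(S_W) @ S_B)
idx = np.argmax(eigvals.real)
v = eigvecs[:, idx].real
v = v / np.linalg.norm(v)

# For two classes this direction is proportional to S_W^-1 (mu1 - mu2)
print("LDA direction:", v)
print("Projected means:", v @ mu1, v @ mu2)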

Extensions to LDA

  1. Quadratic Discriminant Analysis (QDA): Each class uses its own estimate of variance (or covariance), allowing it to handle more complex relationships (a short comparison is sketched after this list).
  2. Flexible Discriminant Analysis (FDA): Uses non-linear combinations of inputs such as splines to handle non-linear separability.
  3. Regularized Discriminant Analysis (RDA): Introduces regularization into the covariance estimate to prevent overfitting.
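As a quick illustration of the first and third ideas with scikit-learn: QuadraticDiscriminantAnalysis fits one covariance matrix per class, and the shrinkage option of LinearDiscriminantAnalysis plays a role similar in spirit to RDA's regularized covariance estimate. The synthetic dataset below is only for demonstration.

Python
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import (LinearDiscriminantAnalysis,
                                            QuadraticDiscriminantAnalysis)
from sklearn.model_selection import cross_val_score

# A synthetic two-class problem; the generator settings are arbitrary
X, y = make_classification(n_samples=300, n_features=5, n_informative=3,
                           n_redundant=0, random_state=0)

models = {
    'LDA (shared covariance)': LinearDiscriminantAnalysis(),
    'QDA (per-class covariance)': QuadraticDiscriminantAnalysis(),
    # Shrinking the covariance estimate acts as regularization
    'Shrinkage LDA': LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto'),
}

for name, model in models.items():
    score = cross_val_score(model, X, y, cv=5).mean()
    print(f'{name}: {score:.3f}')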

Implementation of LDA using Python

In this implementation we will perform Linear Discriminant Analysis on the Iris dataset using the Scikit-learn library.

  • StandardScaler(): Standardizes the features to ensure they have a mean of 0 and a standard deviation of 1 removing the influence of different scales.
  • fit_transform(): Standardizes the feature data by applying the transformation learned from the training data ensuring each feature contributes equally.
  • LabelEncoder(): Converts categorical labels into numerical values that machine learning models can process.
  • fit_transform() on y: Transforms the target labels into numerical values for use in classification models.
  • LinearDiscriminantAnalysis(): Reduces the dimensionality of the data by projecting it into a lower-dimensional space while maximizing the separation between classes.
  • transform() on X_test: Applies the learned LDA transformation to the test data to maintain consistency with the training data.
Python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix

# Load the Iris dataset into a DataFrame
iris = load_iris()
dataset = pd.DataFrame(columns=iris.feature_names, data=iris.data)
dataset['target'] = iris.target

# Split into features (first four columns) and target (last column)
X = dataset.iloc[:, 0:4].values
y = dataset.iloc[:, 4].values

# Standardize the features and encode the target labels
sc = StandardScaler()
X = sc.fit_transform(X)
le = LabelEncoder()
y = le.fit_transform(y)

# Hold out 20% of the data for testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Reduce the 4-dimensional data to 2 LDA components
lda = LinearDiscriminantAnalysis(n_components=2)
X_train = lda.fit_transform(X_train, y_train)
X_test = lda.transform(X_test)

# Visualize the training data in the 2-D LDA space
plt.scatter(X_train[:, 0], X_train[:, 1],
            c=y_train, cmap='rainbow',
            alpha=0.7, edgecolors='b')
plt.show()

# Train a classifier on the LDA-transformed features
classifier = RandomForestClassifier(max_depth=2, random_state=0)
classifier.fit(X_train, y_train)
y_pred = classifier.predict(X_test)

# Evaluate on the held-out test set
print('Accuracy : ' + str(accuracy_score(y_test, y_pred)))
conf_m = confusion_matrix(y_test, y_pred)
print(conf_m)

Output:

Accuracy : 0.9
[[ 8 0 0]
[ 0 8 2]
[ 0 1 11]]

Scatter plot of the iris data mapped into 2D

This scatter plot shows three distinct groups of data points, represented by different colors. The group on the right (dark blue) is clearly separated from the others, indicating it is very different. The other two groups (red and light blue) are positioned closer together with some overlap, suggesting they are more similar and harder to separate.

Advantages of LDA

  • Simple and computationally efficient.
  • Works well even when the number of features is much larger than the number of training samples.
  • Can handle multicollinearity.

Disadvantages of LDA

  • Assumes Gaussian distribution of data which may not always be the case.
  • Assumes equal covariance matrices for different classes, which may not hold in all datasets (a quick check of this assumption is sketched after this list).
  • Assumes linear separability which is not always true.
  • May not always perform well in high-dimensional feature spaces.
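A simple way to probe the equal-covariance assumption is to compare each class's covariance matrix. The sketch below does this for the Iris data, printing only the covariance diagonals for brevity.

Python
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()
X, y = iris.data, iris.target

# Compare each class's covariance matrix; large differences between classes
# mean the shared-covariance assumption only holds approximately
for label, name in enumerate(iris.target_names):
    class_cov = np.cov(X[y == label], rowvar=False)
    print(name, np.round(np.diag(class_cov), 3))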

Applications of LDA 

  1. Face Recognition: It is used to reduce the high-dimensional feature space of pixel values in face recognition applications helping to identify faces more efficiently.
  2. Medical Diagnosis: It classifies disease severity as mild, moderate or severe based on patient parameters, helping in decision-making for treatment.
  3. Customer Identification: It can help identify customer segments most likely to purchase a specific product based on survey data.
