Linear Discriminant Analysis in Machine Learning
Last Updated: 18 May, 2025
When working with high-dimensional datasets, it is important to apply dimensionality reduction techniques to make data exploration and modeling more efficient. Linear Discriminant Analysis (LDA), also known as Normal Discriminant Analysis, is a supervised technique that helps separate two or more classes by projecting a higher-dimensional data space onto a lower-dimensional space. It identifies a linear combination of features that best separates the classes within a dataset.
[Image: Two classes overlapping]
For example, suppose we have two classes that need to be separated efficiently. Each class may have multiple features, and using a single feature to classify them may result in overlapping. To solve this, LDA uses multiple features to improve classification accuracy. LDA relies on a few assumptions, and understanding them gives a better picture of how it works.
Key Assumptions of LDA
For LDA to perform effectively, certain assumptions are made:
- Gaussian distribution: the data within each class is assumed to follow a normal (Gaussian) distribution.
- Equal covariance: all classes are assumed to share the same covariance matrix.
- Linear separability: a linear decision boundary is assumed to be sufficient to separate the classes.
If these assumptions are met, LDA can produce very good results. For example, when data points belonging to two classes are plotted and they are not separable along a single feature, LDA will attempt to find a projection that maximizes class separability.
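As a quick, hedged illustration of what these assumptions mean in practice, the short sketch below (not part of LDA itself) estimates the per-class means and covariance matrices on the Iris dataset used later in this article; roughly similar covariance matrices across classes suggest the equal-covariance assumption is reasonable.
Python
import numpy as np
from sklearn.datasets import load_iris

# Illustrative check of the LDA assumptions on the Iris data:
# compare the per-class mean vectors and covariance matrices.
X, y = load_iris(return_X_y=True)
for label in np.unique(y):
    X_c = X[y == label]
    print("class", label, "mean:", X_c.mean(axis=0).round(2))
    print(np.cov(X_c, rowvar=False).round(2))  # per-class covariance estimate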
[Image: Linearly separable dataset]
The image shows an example where the classes (black and green circles) are not separable along a single feature. LDA attempts to separate them using a red dashed line. It uses both axes (X and Y) to generate a new axis in such a way that it maximizes the distance between the means of the two classes while minimizing the variation within each class. This transforms the dataset into a space where the classes are better separated. After projecting the data points onto this new axis, LDA maximizes the class separation: the new axis allows for clearer classification by enhancing the distance between the class means.
[Image: The perpendicular distance between the line and the points]
The perpendicular distance between the decision boundary and the data points helps us visualize how LDA works by reducing within-class variation and increasing separability. After generating the new axis using the criteria mentioned above, all the data points of the classes are plotted on this new axis, as shown in the figure below.
The figure shows how LDA creates a new axis to project the data and separate the two classes effectively along a linear path. However, LDA fails when the means of the class distributions are shared, as it then becomes impossible to find a new axis that makes both classes linearly separable. In such cases we use non-linear discriminant analysis.
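The projection idea can be shown in a few lines of NumPy. The sketch below uses synthetic 2-D data (not the figures above) and an arbitrary unit vector v to illustrate how every point collapses to a single coordinate v^T x on the new axis.
Python
import numpy as np

rng = np.random.default_rng(0)
# Two synthetic 2-D classes (illustrative only)
c1 = rng.normal(loc=[0.0, 0.0], scale=0.5, size=(50, 2))
c2 = rng.normal(loc=[2.0, 2.0], scale=0.5, size=(50, 2))

# A unit vector defining the candidate new axis
v = np.array([1.0, 1.0])
v = v / np.linalg.norm(v)

# Projecting each point gives one coordinate per point on the new axis
p1, p2 = c1 @ v, c2 @ v
print("projected means:", p1.mean().round(2), p2.mean().round(2))
print("projected within-class scatters:", p1.var().round(2), p2.var().round(2))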
How does LDA work?
LDA works by finding directions in the feature space that best separate the classes. It does this by maximizing the difference between the class means while minimizing the spread within each class.
Let's assume we have two classes with d-dimensional samples x_1, x_2, \ldots, x_n, where:
- n_1 samples belong to class c_1
- n_2 samples belong to class c_2.
If x_i represents a data point, its projection onto the line represented by the unit vector v is v^T x_i. Let the means of class c_1 and class c_2 before projection be \mu_1 and \mu_2 respectively. After projection, the new means are \hat{\mu}_1 = v^T \mu_1 and \hat{\mu}_2 = v^T \mu_2.
Our aim is to make the difference |\hat{\mu}_1 - \hat{\mu}_2| large relative to the scatter of the projected samples, so that the classes are well separated. The scatter for the projected samples of class c_1 is calculated as:
s_1^2 = \sum_{x_i \in c_1} (v^T x_i - \hat{\mu}_1)^2
Similarly for class c_2​:
s_2^2 = \sum_{x_i \in c_2} (v^T x_i - \hat{\mu}_2)^2
The goal is to maximize the ratio of the between-class scatter to the within-class scatter, which leads to the following criterion:
J(v) = \frac{(\hat{\mu}_1 - \hat{\mu}_2)^2}{s_1^2 + s_2^2}
For the best separation, we take the eigenvector corresponding to the largest eigenvalue of the matrix S_W^{-1} S_B, where S_W is the within-class scatter matrix and S_B is the between-class scatter matrix.
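A minimal NumPy sketch of this two-class case is given below; the function name fisher_direction is ours, not a library API. It builds the within-class scatter matrix S_W and the between-class scatter matrix S_B from synthetic data and returns the leading eigenvector of S_W^{-1} S_B.
Python
import numpy as np

def fisher_direction(X1, X2):
    # Class means before projection
    mu1, mu2 = X1.mean(axis=0), X2.mean(axis=0)
    # Within-class scatter: sum of the two class scatter matrices
    S_W = (X1 - mu1).T @ (X1 - mu1) + (X2 - mu2).T @ (X2 - mu2)
    # Between-class scatter: outer product of the mean difference
    diff = (mu1 - mu2).reshape(-1, 1)
    S_B = diff @ diff.T
    # Leading eigenvector of S_W^{-1} S_B maximizes J(v); for two classes
    # it is proportional to S_W^{-1} (mu1 - mu2)
    eigvals, eigvecs = np.linalg.eig(np.linalg.inv(S_W) @ S_B)
    v = eigvecs[:, np.argmax(eigvals.real)].real
    return v / np.linalg.norm(v)

rng = np.random.default_rng(0)
X1 = rng.normal([0.0, 0.0], 0.5, size=(40, 2))
X2 = rng.normal([2.0, 1.0], 0.5, size=(40, 2))
print(fisher_direction(X1, X2))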
Extensions to LDA
- Quadratic Discriminant Analysis (QDA): Each class uses its own estimate of variance (or covariance) allowing it to handle more complex relationships.
- Flexible Discriminant Analysis (FDA): Uses non-linear combinations of inputs such as splines to handle non-linear separability.
- Regularized Discriminant Analysis (RDA): Introduces regularization into the covariance estimate to prevent overfitting (see the scikit-learn sketch after this list).
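As a hedged sketch of how two of these extensions look in scikit-learn: QuadraticDiscriminantAnalysis implements QDA, and the shrinkage option of LinearDiscriminantAnalysis (with the 'lsqr' solver) gives a regularized covariance estimate in the spirit of RDA; FDA is not covered here.
Python
from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import (LinearDiscriminantAnalysis,
                                            QuadraticDiscriminantAnalysis)
from sklearn.model_selection import cross_val_score

X, y = load_iris(return_X_y=True)

# QDA: each class gets its own covariance estimate
qda = QuadraticDiscriminantAnalysis()
# Shrinkage LDA: regularized covariance, similar in spirit to RDA
shrinkage_lda = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')

print("QDA CV accuracy:", cross_val_score(qda, X, y, cv=5).mean().round(3))
print("Shrinkage LDA CV accuracy:",
      cross_val_score(shrinkage_lda, X, y, cv=5).mean().round(3))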
Implementation of LDA using Python
In this implementation we will perform linear discriminant analysis on the Iris dataset using the scikit-learn library. The key functions used are described below:
- StandardScaler(): Standardizes the features to ensure they have a mean of 0 and a standard deviation of 1 removing the influence of different scales.
- fit_transform(): Standardizes the feature data by applying the transformation learned from the training data ensuring each feature contributes equally.
- LabelEncoder(): Converts categorical labels into numerical values that machine learning models can process.
- fit_transform() on y: Transforms the target labels into numerical values for use in classification models.
- LinearDiscriminantAnalysis(): Reduces the dimensionality of the data by projecting it into a lower-dimensional space while maximizing the separation between classes.
- transform() on X_test: Applies the learned LDA transformation to the test data to maintain consistency with the training data.
Python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
iris = load_iris()
dataset = pd.DataFrame(data=iris.data, columns=iris.feature_names)
dataset['target'] = iris.target

# Separate the features and the target
X = dataset.iloc[:, 0:4].values
y = dataset.iloc[:, 4].values

# Standardize the features to zero mean and unit variance
sc = StandardScaler()
X = sc.fit_transform(X)

# Encode the target labels as integers
le = LabelEncoder()
y = le.fit_transform(y)

# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Project the data onto 2 linear discriminants
lda = LinearDiscriminantAnalysis(n_components=2)
X_train = lda.fit_transform(X_train, y_train)
X_test = lda.transform(X_test)

# Visualize the training data in the 2-D discriminant space
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train,
            cmap='rainbow', alpha=0.7, edgecolors='b')
plt.show()

# Train a classifier on the reduced data and evaluate it
classifier = RandomForestClassifier(max_depth=2, random_state=0)
classifier.fit(X_train, y_train)
y_pred = classifier.predict(X_test)

print('Accuracy : ' + str(accuracy_score(y_test, y_pred)))
conf_m = confusion_matrix(y_test, y_pred)
print(conf_m)
Output:
Accuracy : 0.9
[[ 8 0 0]
[ 0 8 2]
[ 0 1 11]]
[Image: Scatter plot of the Iris data mapped into 2D]
This scatter plot shows three distinct groups of data points, represented by different colors. The group on the right (dark blue) is clearly separated from the others, indicating it is very different. The other two groups (red and light blue) are positioned closer together with some overlap, suggesting they are more similar and harder to separate.
Advantages of LDA
- Simple and computationally efficient.
- Works well even when the number of features is much larger than the number of training samples.
- Can handle multicollinearity.
Disadvantages of LDA
- Assumes Gaussian distribution of data which may not always be the case.
- Assumes equal covariance matrices for different classes which may not hold in all datasets.
- Assumes linear separability which is not always true.
- May not always perform well in high-dimensional feature spaces.
Applications of LDA
- Face Recognition: It is used to reduce the high-dimensional feature space of pixel values in face recognition applications helping to identify faces more efficiently.
- Medical Diagnosis: It classifies disease severity as mild, moderate or severe based on patient parameters, helping in decision-making for treatment.
- Customer Identification: It can help identify customer segments most likely to purchase a specific product based on survey data.