Mean Shift Clustering using Sklearn

Last Updated : 5 Jan, 2026

Mean Shift clustering is a non-parametric, density-based clustering algorithm. It discovers clusters by locating the modes (peaks) of the data density in feature space and iteratively shifting points toward those high-density regions until convergence. It does not require the number of clusters to be specified in advance: the count follows from the number of modes found, which in turn depends on the chosen bandwidth. It also works well for irregularly shaped clusters.
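To make the shifting step concrete, here is a minimal NumPy sketch of the idea using a flat (uniform) kernel. It is an illustration under simplifying assumptions, not scikit-learn's implementation: the function name mean_shift_modes, the synthetic two-blob data and the bandwidth of 1.5 are all made up for this example.

```python
import numpy as np


def mean_shift_modes(points, bandwidth, max_iter=100, tol=1e-4):
    """Move a copy of every point to the mean of the data points around it."""
    shifted = points.astype(float).copy()
    for _ in range(max_iter):
        largest_move = 0.0
        for i in range(len(shifted)):
            current = shifted[i].copy()
            # Flat kernel: every original point within `bandwidth` counts equally.
            distances = np.linalg.norm(points - current, axis=1)
            new_position = points[distances <= bandwidth].mean(axis=0)
            largest_move = max(largest_move, np.linalg.norm(new_position - current))
            shifted[i] = new_position
        # Stop once no point moved by more than the tolerance.
        if largest_move < tol:
            break
    return shifted


# Two well-separated synthetic blobs (illustrative data only).
rng = np.random.default_rng(0)
blob_a = rng.normal(loc=[0.0, 0.0], scale=0.3, size=(50, 2))
blob_b = rng.normal(loc=[5.0, 5.0], scale=0.3, size=(50, 2))
points = np.vstack([blob_a, blob_b])

modes = mean_shift_modes(points, bandwidth=1.5)

# Points that ended up at (almost) the same location share a mode.
unique_modes = np.unique(np.round(modes, 1), axis=0)
print("Number of modes:", len(unique_modes))
print(unique_modes)
```

Each point ends up at the density peak of its own blob, so counting the distinct end positions gives the number of clusters without it ever being passed in.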

Implementation

Let's see the implementation of mean shift clustering using sklearn:

Step 1: Load the Dataset

We will import the iris dataset from the scikit-learn library and display its first few rows using pandas.

Python
from sklearn.datasets import load_iris
import pandas as pd

iris = load_iris()
iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)
iris_df["species"] = iris.target
iris_df.head()

Output:

[Image: first five rows of iris_df]

Step 2: Select Features and Scale Them

  • X contains all feature columns (sepal length, sepal width, petal length, petal width).
  • Mean Shift uses distance-based calculations, so scaling avoids feature dominance.
  • StandardScaler standardizes each feature to mean = 0 and variance = 1.
  • X_scaled is the standardized data, ready for clustering.
Python
from sklearn.preprocessing import StandardScaler
X = iris_df[iris.feature_names].values
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
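As a quick sanity check, the standardization can be verified by looking at the per-column mean and standard deviation of the scaled array. The snippet below is a small self-contained sketch that reloads the data rather than reusing the variables above.

```python
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler

X = load_iris().data
X_scaled = StandardScaler().fit_transform(X)

# After standardization every column should have mean ~0 and standard deviation ~1.
print("Column means:", X_scaled.mean(axis=0).round(6))
print("Column standard deviations:", X_scaled.std(axis=0).round(6))
```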

Step 3: Estimate Bandwidth Automatically

  • estimate_bandwidth inspects point distances to determine a suitable neighborhood radius.
  • quantile=0.2 means the bandwidth is derived from each point's distances to its nearest 20% of samples, which roughly corresponds to the 20th percentile of pairwise distances.
  • A smaller quantile → smaller bandwidth → more clusters.
  • Bandwidth is critical for how Mean Shift behaves, so automatic estimation helps.
  • The printed value shows what neighborhood size the algorithm will use.
Python
from sklearn.cluster import estimate_bandwidth
bandwidth = estimate_bandwidth(X_scaled, quantile=0.2, n_samples=len(X_scaled))
print("Estimated Bandwidth:", bandwidth)

Output:

Estimated Bandwidth: 1.207017869625092
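To see how the quantile influences the estimate, one option is to call estimate_bandwidth with several values and compare the results. The quantile values below are arbitrary choices for illustration.

```python
from sklearn.cluster import estimate_bandwidth
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler

X_scaled = StandardScaler().fit_transform(load_iris().data)

# Arbitrary quantiles chosen for illustration; larger values look at more neighbors.
quantiles = [0.1, 0.2, 0.3, 0.5]
bandwidths = [estimate_bandwidth(X_scaled, quantile=q) for q in quantiles]

for q, bw in zip(quantiles, bandwidths):
    print(f"quantile={q:.1f} -> bandwidth={bw:.3f}")
```

The estimated bandwidth grows with the quantile, because a larger quantile takes more distant neighbors into account for every point.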

Step 4: Apply Mean Shift Clustering

  • A MeanShift object is created with the computed bandwidth.
  • bin_seeding=True speeds up clustering by using a grid to initialize seeds.
  • fit(X_scaled) performs the iterative shifting toward density peaks.
  • labels contains the cluster ID assigned to each of the 150 iris samples.
  • cluster_centers_ contains the final positions of discovered modes.
  • The number of cluster centers indicates how many clusters Mean Shift detected automatically.
Python
from sklearn.cluster import MeanShift
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X_scaled)

labels = ms.labels_
centers_scaled = ms.cluster_centers_
print("Number of clusters found:", len(centers_scaled))

Output:

Number of clusters found: 3
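Once fitted, the model can also assign unseen samples to the nearest discovered center through its predict method. The sketch below repeats the earlier steps so it runs on its own; the measurements in new_flower are hypothetical and not taken from the dataset description above. Note that new data must go through the same fitted scaler before prediction.

```python
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler

iris = load_iris()
scaler = StandardScaler()
X_scaled = scaler.fit_transform(iris.data)

bandwidth = estimate_bandwidth(X_scaled, quantile=0.2)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(X_scaled)

# Hypothetical flower: sepal length, sepal width, petal length, petal width (cm).
new_flower = [[5.1, 3.5, 1.4, 0.2]]

# Scale with the already-fitted scaler, then pick the closest cluster center.
cluster_id = ms.predict(scaler.transform(new_flower))
print("Assigned cluster:", cluster_id[0])
```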

Step 5: Add Cluster Labels to the DataFrame

  • A new column cluster is added to store Mean Shift's cluster assignments.
  • Printing the first few rows shows how species labels compare to cluster labels.
  • crosstab shows the distribution of real species across discovered clusters.
  • This helps evaluate how well the unsupervised algorithm separates the species.
Python
iris_df["cluster"] = labels
print(iris_df[["species", "cluster"]].head())
print("\nCluster vs Species:")
print(pd.crosstab(iris_df["cluster"], iris_df["species"]))

Output:

[Image: first five rows of the species and cluster columns, followed by the cluster vs. species crosstab]
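The crosstab gives a qualitative picture. For a single-number summary, one possible approach (not part of the walkthrough above) is to compute the adjusted Rand index against the known species and the silhouette score of the clustering, both available in sklearn.metrics. The sketch repeats the earlier steps so it can run on its own.

```python
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import load_iris
from sklearn.metrics import adjusted_rand_score, silhouette_score
from sklearn.preprocessing import StandardScaler

iris = load_iris()
X_scaled = StandardScaler().fit_transform(iris.data)

bandwidth = estimate_bandwidth(X_scaled, quantile=0.2)
labels = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(X_scaled).labels_

# Agreement with the true species: 1.0 is a perfect match, ~0 is chance level.
ari = adjusted_rand_score(iris.target, labels)
print(f"Adjusted Rand index: {ari:.3f}")

# The silhouette score needs at least two clusters, so guard against a single one.
n_clusters = len(set(labels))
if n_clusters > 1:
    print(f"Silhouette score: {silhouette_score(X_scaled, labels):.3f}")
```

The adjusted Rand index uses the species labels, so it is only available when ground truth exists; the silhouette score relies on the data and cluster assignments alone.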

Step 6: Visualize Clusters

Now we will visualize the clusters using matplotlib.

Python
import matplotlib.pyplot as plt
plt.figure(figsize=(7, 5))
plt.scatter(
    iris_df["petal length (cm)"],
    iris_df["petal width (cm)"],
    c=iris_df["cluster"],
    cmap="viridis",
    alpha=0.8
)
plt.xlabel("Petal Length (cm)")
plt.ylabel("Petal Width (cm)")
plt.title("Mean Shift Clustering on Iris Dataset")
plt.colorbar(label="Cluster ID")
plt.show()

Output:

[Image: scatter plot of petal length vs. petal width, colored by cluster ID]

The plot shows the clusters that Mean Shift discovered, projected onto the two petal features.
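The cluster centers live in the scaled feature space, so they have to be mapped back to centimeters before they can be drawn on the same axes. A possible way to do this, assuming the same pipeline as above, is StandardScaler's inverse_transform; the marker style and colors are arbitrary choices.

```python
import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler

iris = load_iris()
scaler = StandardScaler()
X_scaled = scaler.fit_transform(iris.data)

bandwidth = estimate_bandwidth(X_scaled, quantile=0.2)
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(X_scaled)

# Undo the standardization so the centers are expressed in centimeters again.
centers_original = scaler.inverse_transform(ms.cluster_centers_)

plt.figure(figsize=(7, 5))
# Columns 2 and 3 are petal length and petal width.
plt.scatter(iris.data[:, 2], iris.data[:, 3], c=ms.labels_, cmap="viridis", alpha=0.8)
plt.scatter(
    centers_original[:, 2],
    centers_original[:, 3],
    c="red",
    marker="X",
    s=200,
    label="Cluster centers",
)
plt.xlabel("Petal Length (cm)")
plt.ylabel("Petal Width (cm)")
plt.title("Mean Shift Clusters with Centers")
plt.legend()
plt.show()
```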

Applications

  • Image segmentation: grouping pixels into meaningful regions based on color or texture.
  • Object tracking: following the highest-density region that corresponds to a moving object across video frames.
  • Dominant color extraction: finding the main colors of an image for tasks like palette generation or background isolation.
  • Customer segmentation: grouping customers when the number of groups is unknown and clusters may have irregular shapes.
  • Geospatial hotspot detection: identifying dense areas of traffic, population or events.
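
As an illustration of the dominant color use case, the sketch below clusters RGB values with MeanShift and reads the palette off the cluster centers. It uses synthetic pixels instead of a real image, and the three base colors, the noise level and the hand-picked bandwidth of 30 are all assumptions made for this example.

```python
import numpy as np
from sklearn.cluster import MeanShift

# Synthetic "image": 200 noisy pixels around each of three base RGB colors.
rng = np.random.default_rng(42)
true_colors = np.array(
    [[200, 30, 30], [30, 200, 30], [30, 30, 200]], dtype=float
)
pixels = np.vstack(
    [color + rng.normal(scale=5.0, size=(200, 3)) for color in true_colors]
)

# Bandwidth chosen by hand for this synthetic data; real images need tuning.
ms = MeanShift(bandwidth=30.0)
ms.fit(pixels)

# Each cluster center is one dominant color; round to integer RGB values.
dominant_colors = np.round(ms.cluster_centers_).astype(int)
print("Dominant colors (RGB):")
print(dominant_colors)
```

For a real image, the pixel array would typically be the image reshaped to one row per pixel, often subsampled first because Mean Shift becomes slow on large inputs.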