How to split a Dataset into Train and Test Sets using Python
Last Updated: 18 Apr, 2025
One of the most important steps in preparing data for training an ML model is splitting the dataset into training and testing sets. This simply means dividing the data into two parts: one to train the machine learning model (the training set), and another to evaluate how well it performs on unseen data (the testing set). The training set is used to fit the model, so both its features and its target values are known to the model. The test set is held back during training and used only to make predictions and measure how well the model generalizes.
We'll see how to split a dataset into train and test sets using Python, with the scikit-learn library performing the split. Whether you're working with numerical data, text, or images, this is an essential part of any supervised machine learning workflow.
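Conceptually, a random split is just "shuffle, then slice". The sketch below is a hypothetical helper (`simple_train_test_split` is not part of any library) that shows the idea with the standard library only. It rounds the test size up, which matches how scikit-learn treats a fractional `test_size`, but it is an illustration, not scikit-learn's implementation.

```python
import math
import random


def simple_train_test_split(rows, test_fraction=0.2, seed=0):
    """Hypothetical helper: shuffle a copy of `rows`, then slice off the test set."""
    shuffled = list(rows)                   # copy so the caller's list is untouched
    random.Random(seed).shuffle(shuffled)   # seeded RNG -> same order on every run
    n_test = math.ceil(len(shuffled) * test_fraction)  # round the test size up
    test = shuffled[:n_test]
    train = shuffled[n_test:]
    return train, test


data = list(range(10))
train, test = simple_train_test_split(data, test_fraction=0.2, seed=0)
print(len(train), len(test))  # 8 2
```

Every row ends up in exactly one of the two sets, and the seed makes the shuffle reproducible.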
Installation:
The scikit-learn library can be installed using pip:
pip install scikit-learn
Alternatively, it can also be downloaded from here.
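To confirm the installation worked, you can import the package and print its version. This is only a quick sanity check; the exact version string depends on what pip installed.

```python
# Import the installed package (its import name is sklearn, not scikit-learn)
import sklearn

# Print the installed version string
print(sklearn.__version__)
```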
Dataset Splitting
Scikit-learn is one of the most widely used machine learning libraries in Python. It provides a range of tools for building models, pre-processing data, and evaluating performance. For splitting datasets, it provides a handy function called train_test_split() within the model_selection module, making it simple to divide your data into training and testing sets.
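A minimal call on toy NumPy arrays (made-up data, not from the article's dataset) shows the order of the returned arrays: for each input array you get a train part followed by a test part, and rows stay aligned across the inputs.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data: 10 samples with 2 features each; row i is [2i, 2i+1] and its label is i
X = np.arange(20).reshape(10, 2)
y = np.arange(10)

# Two inputs -> four outputs: (X_train, X_test) then (y_train, y_test)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print(X_train.shape, X_test.shape)  # (8, 2) (2, 2)
print(y_train.shape, y_test.shape)  # (8,) (2,)
```

Because each row of `X` encodes its own label (first feature divided by 2), you can check that features and labels were shuffled together.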
Syntax:
train_test_split(*arrays, test_size=None, train_size=None, random_state=None, shuffle=True, stratify=None)
Parameters:
- *arrays: The data you want to split. This can be lists, NumPy arrays, SciPy sparse matrices, or pandas DataFrames, and all of them must have the same number of rows (samples).
- test_size: If a float between 0.0 and 1.0, the proportion of the data that goes into the test set; for example, 0.2 means 20% of the data will be used for testing. If an integer, the absolute number of test samples. If both test_size and train_size are left as None, it defaults to 0.25.
- train_size: Works the same way as test_size, but for the training set. If not set, it is automatically calculated as the complement of test_size.
- random_state: Controls the shuffling applied before the split. Passing a fixed integer (like setting a seed) makes the split the same every time you run the code.
- shuffle: If True, the data is shuffled before splitting, so the train and test sets are random samples of the data. It's True by default. If False, the original order is kept, and stratify must be None.
- stratify: If set (usually to the class labels, e.g. stratify=y), the split keeps the same class distribution in both the train and test sets. It's especially useful for classification problems.
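The sketch below exercises several of these parameters on ten toy samples (assumed data). The sizes follow scikit-learn's rounding: a fractional test_size is rounded up, and the training set gets the remaining rows.

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(10).reshape(-1, 1)  # 10 toy samples, 1 feature

# Neither size given -> 25% test, rounded up: ceil(2.5) = 3 test rows
a_train, a_test = train_test_split(X, random_state=0)
print(len(a_train), len(a_test))  # 7 3

# An integer test_size is an absolute number of samples, not a fraction
b_train, b_test = train_test_split(X, test_size=4, random_state=0)
print(len(b_train), len(b_test))  # 6 4

# Same random_state -> identical split on every call
c1_train, c1_test = train_test_split(X, test_size=0.3, random_state=7)
c2_train, c2_test = train_test_split(X, test_size=0.3, random_state=7)
print(np.array_equal(c1_test, c2_test))  # True

# shuffle=False keeps the original order: the last rows become the test set
d_train, d_test = train_test_split(X, test_size=0.2, shuffle=False)
print(d_test.ravel())  # [8 9]
```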
Example
Let us take some sample data to demonstrate the split. The data can be downloaded from here as a CSV file.

In the example, we first import pandas and sklearn. Then, we load the CSV file using the read_csv() function, which stores the data in a DataFrame called df. We want to predict the house price, which is in the last column, so we set that as y (the target). All the other columns are used as features and stored in X.
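To see what the two iloc selections return, here is a small sketch on a hypothetical three-row DataFrame. The column names and values are made up for illustration; only the position of the target (the last column) mirrors the example.

```python
import pandas as pd

# Hypothetical toy frame: two feature columns, with the target in the last column
df = pd.DataFrame({
    "house_age": [10, 5, 20],
    "distance_to_station": [300.0, 1200.0, 85.0],
    "price": [100.0, 150.0, 120.0],
})

X = df.iloc[:, :-1]  # all rows, every column except the last -> features (DataFrame)
y = df.iloc[:, -1]   # all rows, only the last column -> target (Series)

print(list(X.columns))  # ['house_age', 'distance_to_station']
print(y.name)           # price
```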
We use train_test_split() to split the data:
- test_size=0.05 means 5% of the data is used for testing, and 95% for training.
- random_state=0 ensures the split is the same every time we run the code.
Python
# import modules
import pandas as pd
from sklearn.model_selection import train_test_split

# read the dataset
df = pd.read_csv('Real-estate.csv')

# features: every column except the last; target: the last column (house price)
X = df.iloc[:, :-1]
y = df.iloc[:, -1]

# split the dataset: 5% for testing, 95% for training
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.05, random_state=0)

# check the sizes of the resulting sets
print(X_train.shape, X_test.shape)
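With the split in place, the training set is used to fit a model and the test set to evaluate it. Since Real-estate.csv is not bundled here, this sketch assumes synthetic data (made-up features and a noisy linear target) and fits scikit-learn's LinearRegression with the same test_size=0.05 and random_state=0 as above. The R² score from model.score() is one common way to measure performance on the held-out rows, not the only one.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the real dataset: 200 rows, 3 features,
# and a target that is a noisy linear function of the features (made-up data)
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = X @ np.array([3.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=200)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.05, random_state=0)

model = LinearRegression()
model.fit(X_train, y_train)       # learn only from the training rows
r2 = model.score(X_test, y_test)  # R^2 on rows the model never saw
print(f"test R^2: {r2:.3f}")
```

Note that 5% of 200 rows is only 10 test rows, so the score can vary noticeably with the seed; larger test fractions give a more stable estimate.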
We have now split the dataset into training and testing sets. To learn more about improving your machine learning workflow, you may explore:
- Stratified sampling
- Cross-validation
- Handling imbalanced datasets
- Pre-processing before splitting
- Machine Learning Models
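Of the topics above, stratified sampling is built directly into train_test_split through its stratify parameter. A minimal sketch on made-up, imbalanced labels (90 samples of class 0, 10 of class 1):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Made-up imbalanced labels: 90 samples of class 0, 10 of class 1
y = np.array([0] * 90 + [1] * 10)
X = np.arange(len(y)).reshape(-1, 1)

# stratify=y keeps the 90/10 class ratio in both parts of the split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0)

print(np.bincount(y_train).tolist())  # [72, 8]
print(np.bincount(y_test).tolist())   # [18, 2]
```

Without stratify, a random 20-row test set could by chance contain very few (or none) of the rare class, which makes evaluation unreliable.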