PySpark sampleBy using multiple columns
Last Updated: 26 Apr, 2025
In this article, we are going to learn about PySpark sampleBy using multiple columns in Python.
While processing big data, we often need only a sample of the data rather than the full data set. In PySpark, we can draw such a sample with the sampleBy() function. In this article, we are going to learn how to take a stratified sample based on multiple columns.
sampleBy() function:
sampleBy() is the function that returns a stratified sample, without replacement, based on the fraction given for each stratum. The strata are defined by the values of a single column, and sampling is done within each stratum according to its fraction.
Syntax: DataFrame.sampleBy(col, fractions, seed=None)
Parameters:
- col: a column or string that defines the strata.
- fractions: the sampling fraction for each stratum, given as a dictionary that maps each stratum value to a fraction between 0 and 1; strata not listed are treated as having a fraction of 0.
- seed: random seed (optional).
Returns: A new DataFrame that represents the stratified sample.
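For reference, here is a minimal single-column sketch of sampleBy(); it assumes a data frame like the one built in Example 1 below, and the fractions shown are purely illustrative.
# Illustrative only: one sampling fraction per distinct Roll_Number value
fractions = {1: 0.5, 2: 0.5, 3: 1.0}
data_frame.sampleBy('Roll_Number', fractions, seed=42).show()
Since sampleBy() stratifies on a single column, the steps below switch to the RDD API and combine keyBy() with sampleByKey() to stratify on several columns at once.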
Steps of PySpark sampleBy using multiple columns
Step 1: First of all, import SparkSession from the pyspark.sql module; it is used to create the session.
from pyspark.sql import SparkSession
Step 2: Now, create a Spark session using the getOrCreate() function.
spark_session = SparkSession.builder.getOrCreate()
Step 3: Then, either create the data frame using the createDataFrame() function or read the CSV file.
data_frame = spark_session.read.csv('#Path of CSV file',
                                    sep=',', inferSchema=True, header=True)
or
data_frame = spark_session.createDataFrame([(column_data_1), (column_data_2), (column_data_3)],
                                           ['column_name_1', 'column_name_2', 'column_name_3'])
Step 4: Later on, store the data frame in another variable as it will be used during sampling.
df=data_frame
Step 5: Further, build the fractions mapping: map every row to a tuple of the chosen columns with the map() function, keep the distinct tuples, and pair each tuple with the desired sampling fraction (the collected mapping is illustrated just after the snippet below).
fractions = df.rdd.map(lambda x: (x[column_index_1], x[column_index_2])) \
                  .distinct() \
                  .map(lambda x: (x, fraction)).collectAsMap()
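For instance, with the three-row data frame of Example 1 (columns 'Roll_Number' and 'Fees', fraction 0.4), the collected mapping would look roughly like this:
# {(1, 10000): 0.4, (2, 14000): 0.4, (3, 12000): 0.4}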
Step 6: Moreover, turn every row into a (key, row) pair with the keyBy() function, where the key is the tuple of the chosen columns.
key_df = df.rdd.keyBy(lambda x: (x[column_index_1],x[column_index_2]))
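With Example 1's data, the first keyed element would look roughly like this (key, row) pair:
# ((1, 10000), Row(Roll_Number=1, Fees=10000, Fine=400))
This is why the next step keeps x[1], the original row.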
Step 7: Finally, extract the random sample with the sampleByKey() function, passing the withReplacement boolean, the fractions mapping and, optionally, a seed, then convert the sampled rows back to a data frame and display it.
key_df.sampleByKey(False, fractions).map(lambda x:
    x[1]).toDF(data_frame.columns).show()  # x[1] is the original row, x[0] is the key
Example 1:
In this example, we have created a data frame with the columns ‘Roll_Number‘, ‘Fees‘ and ‘Fine‘, keyed it on two columns (‘Roll_Number‘ and ‘Fees‘), and extracted a stratified sample through the sampleByKey() function with withReplacement=False and a fraction of 0.4. We have extracted the random sample twice through the sampleByKey() function to check whether we get the same rows each time; because no seed is given, the two draws are independent and can return different rows from run to run.
Python3
from pyspark.sql import SparkSession
spark_session = SparkSession.builder.getOrCreate()
data_frame = spark_session.createDataFrame([(1, 10000, 400),
                                            (2, 14000, 500),
                                            (3, 12000, 800)],
                                            ['Roll_Number', 'Fees', 'Fine'])

df = data_frame
print("Data frame:")
df.show()

# Map each distinct (Roll_Number, Fees) tuple to the sampling fraction 0.4
fractions = df.rdd.map(lambda x: (x[0], x[1])).distinct().map(
    lambda x: (x, 0.4)).collectAsMap()

# Key every row by its (Roll_Number, Fees) tuple
key_df = df.rdd.keyBy(lambda x: (x[0], x[1]))

print("Sample 1: ")
key_df.sampleByKey(False, fractions).map(
    lambda x: x[1]).toDF(data_frame.columns).show()

print("Sample 2: ")
key_df.sampleByKey(False, fractions).map(
    lambda x: x[1]).toDF(data_frame.columns).show()
Output:
Data frame:
+-----------+-----+----+
|Roll_Number| Fees|Fine|
+-----------+-----+----+
| 1|10000| 400|
| 2|14000| 500|
| 3|12000| 800|
+-----------+-----+----+
Sample 1:
+-----------+-----+----+
|Roll_Number| Fees|Fine|
+-----------+-----+----+
| 3|12000| 800|
+-----------+-----+----+
Sample 2:
+-----------+-----+----+
|Roll_Number| Fees|Fine|
+-----------+-----+----+
| 3|12000| 800|
+-----------+-----+----+
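Note that Sample 1 and Sample 2 above are independent draws; without a seed they are not guaranteed to match between runs. To make the draws repeatable, a seed can be passed as the third argument of sampleByKey(), as Example 2 does. A minimal sketch, reusing the key_df and fractions built above (the seed value 26 is arbitrary):
# Passing a seed makes repeated draws return the same rows
key_df.sampleByKey(False, fractions, 26).map(
    lambda x: x[1]).toDF(data_frame.columns).show()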
Example 2:
In this example, we have read the data frame from a CSV file (link), keyed it on three columns (‘class‘, ‘fees‘ and ‘discount‘), and extracted a stratified sample through the sampleByKey() function with withReplacement=True, a fraction of 0.4 and a fixed seed. We have extracted the random sample twice through the sampleByKey() function to check whether we get the same rows each time; because the seed is fixed, both draws return the same rows.
Python3
from pyspark.sql import SparkSession
spark_session = SparkSession.builder.getOrCreate()
data_frame = spark_session.read.csv(
    '/content/drive/MyDrive/Colab Notebooks/class_data.csv',
    sep=',', inferSchema=True, header=True)

df = data_frame
print("Data frame: ")
df.show()

# Map each distinct (class, fees, discount) tuple to the sampling fraction 0.4
fractions = df.rdd.map(lambda x: (x[2], x[3], x[4])).distinct().map(
    lambda x: (x, 0.4)).collectAsMap()

# Key every row by its (class, fees, discount) tuple
key_df = df.rdd.keyBy(lambda x: (x[2], x[3], x[4]))

print("Sample 1: ")
key_df.sampleByKey(True, fractions, 4).map(
    lambda x: x[1]).toDF(data_frame.columns).show()

print("Sample 2: ")
key_df.sampleByKey(True, fractions, 4).map(
    lambda x: x[1]).toDF(data_frame.columns).show()
Output:
Data frame:
+-------+--------------+-----+-----+--------+
| name| subject|class| fees|discount|
+-------+--------------+-----+-----+--------+
| Arun| Maths| 10|12000| 400|
| Aniket|Social Science| 11|15000| 600|
| Ishita| English| 9| 9000| 0|
|Pranjal| Science| 12|18000| 1000|
|Vinayak| Computer| 12|18000| 500|
+-------+--------------+-----+-----+--------+
Sample 1:
+------+-------+-----+----+--------+
| name|subject|class|fees|discount|
+------+-------+-----+----+--------+
|Ishita|English| 9|9000| 0|
+------+-------+-----+----+--------+
Sample 2:
+------+-------+-----+----+--------+
| name|subject|class|fees|discount|
+------+-------+-----+----+--------+
|Ishita|English| 9|9000| 0|
+------+-------+-----+----+--------+