How to Aggregate Data with PySpark
PySpark has functions for aggregating data in a similar way to SQL. Here you'll learn how to use these functions.
There are two main ways to aggregate data in PySpark:
- Using groupBy() with its built-in aggregation methods
- Using agg() with PySpark SQL functions
I usually use the agg() function as it gives you more summary functions to choose from, lets you include more than one summary statistic, and makes it easy to name your summary columns with aliases.
However, the first option is good if you’re just after a quick sum or making a basic frequency table to check your data.
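The examples below assume you already have a SparkSession running and a DataFrame called df to work with. If you want to follow along, here is a minimal sketch that builds a toy df with the region, profit and weight_lb columns used later on (the values are made up purely for illustration):

from pyspark.sql import SparkSession

# Start (or reuse) a local SparkSession
spark = SparkSession.builder.appName('aggregation_examples').getOrCreate()

# A tiny made-up DataFrame with the columns used in the examples below
df = spark.createDataFrame(
    [
        ('North', 100.0, 150.0),
        ('North', 250.0, 180.0),
        ('South', 300.0, 200.0),
        ('South', 50.0, 120.0),
    ],
    ['region', 'profit', 'weight_lb'],
)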
Let’s look at both these in more detail now.
1. groupBy()
Usually when aggregating data, you group your data by one or more columns and then summarise other columns.
If you're familiar with SQL, this is similar to specifying summary functions such as sum(column1) in the select statement and grouping columns in the group by statement.
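If it helps to see that analogy side by side, here's a sketch of the same aggregation written both ways. It assumes the toy df from above, registered under an arbitrary temp view name of sales:

# Register the DataFrame as a temp view so it can be queried with SQL
df.createOrReplaceTempView('sales')

# SQL version: summary function in the select, grouping column in the group by
spark.sql('SELECT region, SUM(profit) AS total_profit FROM sales GROUP BY region').show()

# Equivalent PySpark version
df.groupBy('region').sum('profit').show()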
You can group your data using the groupBy() method and then use some of the built-in aggregation methods such as count(), sum(), min(), max() and avg().
For example:
df.groupBy().sum('profit').show()
This will give you the total profit for the entire df DataFrame as there aren't any columns specified in the groupBy() method.
The example below will give you the sum of profit for each region:
df.groupBy('region').sum('profit').show()
You can of course use other aggregation methods, such as max to get the maximum profit for each region:
df.groupBy('region').max('profit').show()
One function I use very often for creating a quick frequency table is count(), for example:
active_customers.groupBy('loyalty_tier').count().show()
This is useful when you need to get a feel for your data when working with huge DataFrames.
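When a frequency table has lots of groups, it's also handy to sort it so the most common values come first. Here's a quick sketch assuming the same active_customers DataFrame; count() names its output column count, so you can order by it directly:

# Frequency table sorted with the most common loyalty tiers at the top
active_customers.groupBy('loyalty_tier') \
    .count() \
    .orderBy('count', ascending=False) \
    .show()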
If you want to define a new column first and then aggregate on the new column, you can use withColumn() before groupBy(), for example:
df.withColumn('weight_kg', df.weight_lb * 0.4535924) \
.groupBy('region') \
.avg('weight_kg').show()
This will give you the average weight in kg by region.
2. agg()
The agg() method lets you use any of the aggregate functions in the pyspark.sql.functions module.
📌 Remember:
You can import the functions module from the pyspark.sql package like this:
from pyspark.sql import functions as F
Calling this "F" is standard practice and makes sure we can distinguish between PySpark SQL functions and built-in Python functions.
Using the functions module means you aren't just confined to the basic functions from the GroupedData class. You can now use functions like stddev() and countDistinct():
flights.groupBy('flight_number').agg(F.stddev('delay')).show()
You can also include more than one aggregate function in the agg() method, for example:
flights.groupBy('flight_number').agg(F.stddev('delay'), F.mean('delay')).show()
You can also apply aliases to your output by using .alias('column_alias') after each aggregate function:
flights.groupBy('flight_number').agg(F.stddev('delay').alias('delay_stddev'), F.mean('delay').alias('delay_mean')).show()
Another useful aggregate function is countDistinct(), which does what it says on the tin! It can be used like this:
df.groupBy('year').agg(F.countDistinct('id').alias('distinct_ids')).show()
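On very large DataFrames an exact distinct count can get expensive. The pyspark.sql.functions module also has approx_count_distinct, which trades a little accuracy for speed; here's a quick sketch using the same hypothetical year and id columns:

# Approximate distinct count - faster on huge DataFrames, at the cost of some accuracy
df.groupBy('year').agg(F.approx_count_distinct('id').alias('approx_distinct_ids')).show()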
❓ Did you know?
You can use your calculated summary in a filter, similar to a having clause in standard SQL, for example:
df.groupBy('department') \
    .agg(F.sum('salary').alias('sum_salary')) \
    .where(F.col('sum_salary') > 500000) \
    .show()
This will return all departments where the total salary is greater than 500k.
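For comparison, here's a sketch of the same filter written as an actual having clause in SQL, assuming the same DataFrame registered under a hypothetical employees temp view:

# Register a temp view so the same filter can be written with HAVING
df.createOrReplaceTempView('employees')

spark.sql("""
    SELECT department, SUM(salary) AS sum_salary
    FROM employees
    GROUP BY department
    HAVING SUM(salary) > 500000
""").show()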
🎁 Bonus:
You can use the corr function to get the correlation coefficient between two columns:
df.agg(F.corr('column1', 'column2').alias('correlation')).show()
There you have it! You can now aggregate your data in PySpark using either groupBy() or agg().
Check out these posts to find out more about querying data with PySpark and joining DataFrames.