Posting this on the off-chance it's useful to someone else...
On my current project, we have a tonne of Spark logic which needs to aggregate and (generally) sum data grouped by different keys, in order to then apply business rules at the record level (i.e. if the total of amount x, when grouped by a, b and c, is over a threshold then apply one calculation to each record in the group, otherwise apply another - that sort of thing). Window functions are an obvious way to do this. Say you have employee data and want the total salary grouped by employment status and gender; a window example in PySpark might look something like this:
from pyspark.sql import functions as f
from pyspark.sql import Window
from pyspark.sql.types import IntegerType

window = Window.partitionBy("employment_status", "gender")
df.withColumn("total_salary", f.sum("annual_salary").over(window)) \
.select("first_name", "last_name", "gender", "employment_status", "annual_salary", "total_salary") \
.orderBy("last_name", "first_name") \
.show()
+----------+----------+------+-----------------+-------------+------------+
|first_name| last_name|gender|employment_status|annual_salary|total_salary|
+----------+----------+------+-----------------+-------------+------------+
| | | | | | 3320667.0|
| Edwin| Abbott| M| | 139549| 8286529.0|
| Kazuko| Abbott| F| PE| 58642| 2.4893626E7|
Which works fine, but if you have a huge volume of data (say in the hundreds of millions of rows -- obviously not employee data at that point, but it was easier to find test employee data for these examples 🍀), and you don't have a good distribution of records in each group, then you're likely to hit issues with that skewed data -- Spark tries to send all the data for a group to a single executor to do the work, and the executor blows up as a consequence. Particularly if your data is wide (a large number of columns).
Enter salting.
There are a number of articles on salting (like this one on Medium), but the basic principle is to bucket the data in a group using a column with a random number (the salt) -- let's say between 1 and 100 -- so that no single group ends up with too many rows, and then do two passes: aggregation with the salt, and then without. Here's an example of the first pass:
# random salt value per row, between 0 and 100
saltval = f.round(f.rand() * 100)

salted_window = Window.partitionBy("employment_status", "gender", "salt") \
    .orderBy("employment_status", "gender", "salt")

df.withColumn("salt", f.lit(saltval).cast(IntegerType())) \
    .withColumn("row_number", f.row_number().over(salted_window)) \
    .withColumn("salted_total_salary",
                f.when(f.col("row_number") == 1,
                       f.sum(f.col("annual_salary")).over(salted_window)).otherwise(f.lit(0))) \
    .show()
We're also using row_number in this example, so that only the first record in the group gets the salted total (otherwise we'd get the same number on every record in a group, which causes gross-up problems later). Now we could just sum up the salted total salary using the "unsalted" window:
salted_window = Window.partitionBy("employment_status", "gender", "salt") \
    .orderBy("employment_status", "gender", "salt")
window = Window.partitionBy("employment_status", "gender")

df.withColumn("salt", f.lit(saltval).cast(IntegerType())) \
    .withColumn("row_number", f.row_number().over(salted_window)) \
    .withColumn("salted_total_salary",
                f.when(f.col("row_number") == 1,
                       f.sum(f.col("annual_salary")).over(salted_window))) \
    .withColumn("total_salary", f.sum("salted_total_salary").over(window)) \
    .select("first_name", "last_name", "gender", "employment_status", "annual_salary", "total_salary") \
    .orderBy("last_name", "first_name") \
    .show()
The problem with this is that there's no reduction step here (i.e. nothing reduces the amount of data Spark has to deal with), so immediately unsalting the data simply means all your data (still skewed), plus even more columns (the row number and the salt), is shuffled to an executor. We've not fixed the original problem at all (in fact, we've made it slightly worse). This should've been obvious from that Medium article I linked to above -- each of those examples has a filter to reduce the data volume. We missed that nugget -- though in fact, performing salted, followed by unsalted, window functions did seem to make some difference when testing (at least some of our data sets managed to squeak through).
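For the record, the nugget in those articles is a filter between the two passes -- something like the sketch below (using the same windows and saltval as above; this is just an illustration of the idea we missed, not code we actually ran, and it drops the when()/otherwise trick since the filter makes it unnecessary). Filtering down to one row per salted group means far less data gets shuffled for the unsalted pass, but you'd still need to join the totals back to the original rows afterwards.

# Sketch only -- keep just one row per salted group, carrying that group's salted total
salted_totals = df.withColumn("salt", f.lit(saltval).cast(IntegerType())) \
    .withColumn("row_number", f.row_number().over(salted_window)) \
    .withColumn("salted_total_salary", f.sum("annual_salary").over(salted_window)) \
    .filter(f.col("row_number") == 1)

# The unsalted pass now only shuffles one row per (employment_status, gender, salt)
salted_totals.withColumn("total_salary", f.sum("salted_total_salary").over(window)) \
    .select("employment_status", "gender", "salt", "total_salary") \
    .show()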
So what's the answer, if you need a total with each of your rows and are hitting problems with skew?
Don't use window functions for these cases. Instead, stick with groupBy and agg, then join back to your original data to use the totals:
from pyspark.sql import functions as f
from pyspark.sql.types import IntegerType

# random salt value per row, between 0 and 100
saltval = f.round(f.rand() * 100)

# first pass: aggregate with the salt; second pass: roll the salted totals up without it
df1 = df.withColumn("salt", f.lit(saltval).cast(IntegerType())) \
    .groupBy("employment_status", "gender", "salt") \
    .agg(f.sum("annual_salary").alias("salted_total_salary")) \
    .groupBy("employment_status", "gender") \
    .agg(f.sum("salted_total_salary").alias("total_salary")) \
    .orderBy("employment_status", "gender") \
    .alias("df1")

# join the totals back to the original rows
df.join(df1, (df.employment_status == df1.employment_status) &
        (df.gender == df1.gender), "leftouter") \
    .select("first_name", "last_name", "df1.gender", "df1.employment_status", "annual_salary", "total_salary") \
    .orderBy("last_name", "first_name") \
    .show()
Now we're aggregating with the salt and then without, so each step gives Spark the data reduction it needs to cope with the volume of data in the groups (side note: a broadcast join can also be useful in the above -- see below).
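On that side note: since df1 only has one row per employment status and gender combination after the aggregation, it's small enough to send to every executor rather than shuffling the big DataFrame for the join. A rough sketch of what that hint looks like, assuming the same df and df1 as above:

# df1 is tiny after the aggregation, so hint Spark to broadcast it to every executor
# and avoid shuffling the large df for the join
df.join(f.broadcast(df1), (df.employment_status == df1.employment_status) &
        (df.gender == df1.gender), "leftouter") \
    .select("first_name", "last_name", "df1.gender", "df1.employment_status", "annual_salary", "total_salary") \
    .orderBy("last_name", "first_name") \
    .show()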