PySpark Big Data Processing: Production Patterns That Scale
A fintech team I worked with spent three weeks tuning a PySpark job that aggregated transaction records for daily risk reports. It ran fine on 10 million rows in staging. At 2 billion rows in production, it silently stalled for six hours, then crashed the cluster with a java.lang.OutOfMemoryError: GC overhead limit exceeded. The root cause wasn't bad code; it was a single unpartitioned join that forced every executor to shuffle 400GB of data through a single node. One line of misunderstood API destroyed a week's worth of cluster credits.
PySpark sits at the intersection of Python's ecosystem and Apache Spark's distributed execution engine. That's a powerful combination, but it's also a trap for developers who treat it like pandas with a bigger machine. Spark doesn't run your Python the way you think it does. Your DataFrame transformations are lazy. Your joins can silently cause catastrophic data skew. Your Python UDFs push every row across the boundary between the JVM and a Python worker process, which can cut throughput by 10x. Understanding the execution model isn't academic: it's the difference between a job that completes in 8 minutes and one that runs your cloud bill into the thousands.
After this article, you'll know how to configure a SparkSession for production, write transformations that respect Spark's lazy evaluation model, tune partitioning to eliminate shuffle bottlenecks, debug skewed joins using the Spark UI, and write Spark-native aggregations instead of Python UDFs that strangle your executors. Concrete patterns. Runnable code. The failure modes that textbooks skip.
SparkSession Setup and the Execution Model You Must Understand First
Most tutorials show you spark = SparkSession.builder.getOrCreate() and move on. That's like showing someone a car key and skipping the part about combustion engines. Before you write a single transformation, you need to understand what Spark actually does with your code, because it doesn't run it when you think it does.
Spark uses lazy evaluation. Every transformation you write (filter, select, join, groupBy) builds a logical query plan. Nothing executes until you call an action: show(), count(), collect(), or a write such as .write.parquet(). This is why you can chain twenty transformations and Spark will optimize the entire chain before touching a single byte of data. The Catalyst optimizer reorders predicates, prunes unused columns, and sometimes rewrites your join strategy entirely. It's genuinely impressive, until you start debugging and wonder why the print statements inside your UDF never fire: nothing has run yet, and when it does run, it runs on the executors, so the output lands in executor logs rather than your driver console.
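To make the plan-versus-action split concrete, here is a toy model in plain Python. It is not the PySpark API: `LazyFrame`, `with_column`, and the `source_reads` counter are hypothetical names for this sketch. Transformations only append steps to a plan. Each action replays the whole plan from the source, which is also why two actions on an uncached DataFrame read the data twice.

```python
# Toy model of lazy evaluation in plain Python (NOT the PySpark API).
# Transformations only record steps; every action replays the plan from the source.
from typing import Callable, Dict, Iterable, Iterator, List, Optional

Row = Dict[str, object]
Step = Callable[[Iterator[Row]], Iterator[Row]]


class LazyFrame:
    """Hypothetical stand-in for a DataFrame: a data source plus recorded steps."""

    def __init__(self, source: Callable[[], Iterable[Row]],
                 steps: Optional[List[Step]] = None,
                 stats: Optional[Dict[str, int]] = None) -> None:
        self._source = source
        self._steps: List[Step] = steps or []
        self.stats = stats if stats is not None else {"source_reads": 0}
        self._cache_requested = False
        self._cached: Optional[List[Row]] = None

    # --- Transformations: return a new plan, touch no data ---
    def filter(self, predicate: Callable[[Row], bool]) -> "LazyFrame":
        def step(rows: Iterator[Row]) -> Iterator[Row]:
            return (r for r in rows if predicate(r))
        return LazyFrame(self._source, self._steps + [step], self.stats)

    def with_column(self, name: str, fn: Callable[[Row], object]) -> "LazyFrame":
        def step(rows: Iterator[Row]) -> Iterator[Row]:
            return ({**r, name: fn(r)} for r in rows)
        return LazyFrame(self._source, self._steps + [step], self.stats)

    def cache(self) -> "LazyFrame":
        self._cache_requested = True  # Lazy, like Spark: filled by the NEXT action
        return self

    # --- Actions: execute the recorded plan (or serve it from the cache) ---
    def _execute(self) -> List[Row]:
        if self._cached is not None:
            return self._cached
        self.stats["source_reads"] += 1  # Every uncached action re-reads the source
        rows: Iterator[Row] = iter(self._source())
        for step in self._steps:
            rows = step(rows)
        result = list(rows)
        if self._cache_requested:
            self._cached = result
        return result

    def count(self) -> int:
        return len(self._execute())

    def collect(self) -> List[Row]:
        return list(self._execute())


data = [{"amount": 50.0}, {"amount": 150.0}, {"amount": 900.0}]
source_df = LazyFrame(lambda: data)
high = (
    source_df
    .filter(lambda r: r["amount"] > 100)
    .with_column("tier", lambda r: "CRITICAL" if r["amount"] > 500 else "HIGH")
)
print("reads after building the plan:", source_df.stats["source_reads"])
print("count:", high.count(), "| reads:", source_df.stats["source_reads"])
print("rows:", high.collect(), "| reads:", source_df.stats["source_reads"])
```

In this model, cache() only helps the actions issued after it, because the cache is filled by the next action rather than by cache() itself.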
The DAG (Directed Acyclic Graph) is Spark's execution blueprint. Each action triggers one or more jobs. Each job splits into stages wherever a shuffle is required, and stages split into tasks that run in parallel across executor cores. Shuffles are expensive because they write intermediate data to disk and move it across the network between executors. Wide transformations (groupBy, join, distinct, repartition) cause a shuffle, with one important exception: a broadcast join ships the small side to every executor instead of shuffling the large one. Narrow transformations (filter, select, withColumn) never shuffle. This distinction drives every performance decision you'll make in production.
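The stage boundaries can be sketched with a toy planner. It assumes a simple linear chain of operations. Real Spark plans are DAGs, a groupBy performs a partial aggregation before its shuffle, and AQE can rewrite plans at runtime. `split_into_stages` and the operation names are hypothetical. The only idea the planner encodes is that each shuffle starts a new stage.

```python
# Toy stage planner: cuts a linear chain of DataFrame operations into stages
# at every shuffle boundary. A simplification, not Spark's DAGScheduler.
from typing import List

NARROW_OPS = {"filter", "select", "withColumn"}
WIDE_OPS = {"groupBy", "join", "distinct", "repartition"}


def split_into_stages(ops: List[str], broadcast_joins: bool = False) -> List[List[str]]:
    """Start a new stage at each wide operation (the side that reads shuffled data).

    With broadcast_joins=True, 'join' is treated as narrow, mirroring how a
    broadcast join avoids shuffling the large side.
    """
    stages: List[List[str]] = []
    current: List[str] = []
    for op in ops:
        if op not in NARROW_OPS and op not in WIDE_OPS:
            raise ValueError(f"unknown operation: {op}")
        is_shuffle = op in WIDE_OPS and not (broadcast_joins and op == "join")
        if is_shuffle and current:
            stages.append(current)  # Shuffle boundary: close the running stage
            current = []
        current.append(op)
    if current:
        stages.append(current)
    return stages


pipeline = ["filter", "select", "groupBy", "withColumn", "join", "filter"]
for i, stage in enumerate(split_into_stages(pipeline)):
    print(f"stage {i}: {stage}")
```

Two shuffles produce three stages. Broadcasting the join side removes one boundary, which is exactly the saving the broadcast join buys you.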
```python
# io.thecodeforge: Python tutorial
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, TimestampType

# Production SparkSession: never use defaults for anything beyond local testing.
# spark.sql.shuffle.partitions defaults to 200, which is wrong at both extremes:
# 200 near-empty tasks for small datasets, OOM-sized tasks for massive ones.
# Rule of thumb: target 100-200MB of data per partition after a shuffle.
spark = (
    SparkSession.builder
    .appName("transaction-risk-aggregator")  # Shows up in the Spark UI, so make it meaningful
    .config("spark.sql.shuffle.partitions", "400")  # Roughly shuffled_data_GB * 1024 / 200
    .config("spark.sql.adaptive.enabled", "true")  # AQE: re-optimizes the plan mid-execution (Spark 3.x)
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")  # AQE merges small post-shuffle partitions
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")  # Faster than the Java serializer for RDD data
    .config("spark.sql.autoBroadcastJoinThreshold", "50mb")  # Auto-broadcast join sides under 50MB: no shuffle
    .getOrCreate()
)

# Cut INFO log noise; WARN is usually the right level, even in development
spark.sparkContext.setLogLevel("WARN")

# Define the schema explicitly. For CSV/JSON, inferSchema costs an extra full pass
# over the data and guesses wrong on messy input; for Parquet, an explicit schema
# pins down exactly which columns and types the job expects.
transaction_schema = StructType([
    StructField("transaction_id", StringType(), nullable=False),
    StructField("account_id", StringType(), nullable=False),
    StructField("merchant_id", StringType(), nullable=False),
    StructField("amount_usd", DoubleType(), nullable=False),
    StructField("transaction_ts", TimestampType(), nullable=False),
    StructField("risk_score", DoubleType(), nullable=True),  # nullable=True for fields that may be missing
])

# Parquet is the production format; CSV is for demos and data exports.
# If ingestion wrote the data with partitionBy("year", "month", "day"), Spark can
# skip irrelevant directories at read time. Here the path itself selects January 2024.
transaction_df = (
    spark.read
    .schema(transaction_schema)  # Enforce the expected schema; incompatible files fail at read time
    .option("mergeSchema", "false")  # Don't silently accept schema drift
    .parquet("/data/transactions/year=2024/month=01/")  # Only January's files are read
)

# Lazy transformations: zero data has moved and zero files have been opened.
# Spark has only recorded the intent in the query plan.
high_risk_transactions = (
    transaction_df
    .filter(F.col("risk_score") > 0.85)  # Predicate pushdown: Spark pushes this into the Parquet reader
    .filter(F.col("amount_usd") > 100.0)  # Stacked filters: Catalyst combines them into one pass
    .select("transaction_id", "account_id", "merchant_id", "amount_usd", "risk_score")  # Column pruning
    .withColumn(
        "risk_tier",  # Derived column, evaluated lazily with the rest
        F.when(F.col("risk_score") > 0.95, F.lit("CRITICAL"))
        .when(F.col("risk_score") > 0.85, F.lit("HIGH"))
        .otherwise(F.lit("MEDIUM")),
    )
)

# explain() shows the physical plan WITHOUT executing the query. Use it to verify
# predicate pushdown and to catch accidental cross joins.
high_risk_transactions.explain(mode="formatted")

# Cache only if the DataFrame feeds two or more actions in the same job.
# cache() is lazy too: it marks the DataFrame, and the NEXT action fills the cache.
high_risk_transactions.cache()

# First action: THIS is when Spark reads the files, runs the DAG, and populates the cache.
record_count = high_risk_transactions.count()
print(f"High-risk transactions in January 2024: {record_count:,}")

# Second action: served from the cache instead of re-reading the Parquet files.
high_risk_transactions.show(5, truncate=False)

# Free executor memory once the cached data is no longer needed.
high_risk_transactions.unpersist()
```
```
* Filter (1)
+- * ColumnarToRow (2)
   +- Scan parquet [transaction_id, account_id, merchant_id, amount_usd, risk_score] (3)
PushedFilters: [IsNotNull(risk_score), IsNotNull(amount_usd), GreaterThan(risk_score,0.85), GreaterThan(amount_usd,100.0)]
PartitionFilters: []
ReadSchema: struct<transaction_id:string,account_id:string,...>
High-risk transactions in January 2024: 142,837
+----------------+----------+-----------+----------+----------+---------+
|transaction_id  |account_id|merchant_id|amount_usd|risk_score|risk_tier|
+----------------+----------+-----------+----------+----------+---------+
|TXN-8821937-A   |ACC-00421 |MER-9921   |1205.50   |0.97      |CRITICAL |
|TXN-8822041-B   |ACC-00887 |MER-1043   |342.00    |0.91      |HIGH     |
|TXN-8822198-C   |ACC-00421 |MER-0055   |89500.00  |0.99      |CRITICAL |
|TXN-8822301-D   |ACC-01204 |MER-4421   |175.25    |0.88      |HIGH     |
|TXN-8822509-E   |ACC-00034 |MER-9921   |22000.00  |0.96      |CRITICAL |
+----------------+----------+-----------+----------+----------+---------+
only showing top 5 rows
```
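The sizing rule from the code comments above (roughly 200MB of shuffled data per partition) is simple arithmetic. This sketch assumes you measure the shuffle size, for example from the Shuffle Write column of the stage in the Spark UI. The input size on disk is only a rough proxy, because filters, projections, and compression all change how much data actually moves. The helper names are hypothetical.

```python
# Sizing spark.sql.shuffle.partitions from the amount of data a shuffle moves,
# using a ~200MiB-per-partition rule of thumb.
import math

MiB = 1024 * 1024
GiB = 1024 * MiB


def suggest_shuffle_partitions(shuffle_bytes: int, target_mib: int = 200, minimum: int = 1) -> int:
    """Number of post-shuffle partitions so each holds roughly target_mib of data."""
    return max(minimum, math.ceil(shuffle_bytes / (target_mib * MiB)))


def mib_per_partition(shuffle_bytes: int, partitions: int) -> float:
    """Average post-shuffle partition size for a given partition count."""
    return shuffle_bytes / partitions / MiB


for size_gib in (1, 80, 100, 500):
    n = suggest_shuffle_partitions(size_gib * GiB)
    default_mib = mib_per_partition(size_gib * GiB, 200)
    print(f"{size_gib:>4} GiB shuffled -> {n:>5} partitions "
          f"(the default 200 would give {default_mib:,.0f} MiB each)")
```

With AQE's partition coalescing enabled, erring on the high side is usually the safer choice, since AQE can merge partitions that turn out too small but cannot split ones that are too large (outside its skew-join handling).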
Joins, Shuffles, and the Data Skew Problem That Tanks Production Jobs
Joins are where production PySpark jobs go to die. Not because joins are bad, but because developers write them without thinking about what Spark has to do physically to execute them. When you join two DataFrames without broadcasting one side, Spark needs to get rows with matching keys into the same partition, which means shuffling both sides across the network. On a well-distributed dataset, this is fine. On a skewed dataset, where one key represents 40% of your data, the single task that owns that key gets buried while the rest finish and sit idle. Your job appears to be 99% complete for six hours, then either finishes hours late or crashes.
The most common skew pattern I see in production is joining on customer_id or merchant_id in transactional data. Real-world data isn't uniform: your top merchant processes ten thousand times more transactions than your median merchant. When you join your transactions table against a merchant metadata table on merchant_id, every record for that top merchant hashes to the same shuffle partition and is processed by a single task. I've personally watched this kill a Spark job at a payments company at 11pm on a Friday. The job had run fine for months; then the top merchant's volume doubled during a flash sale, and suddenly one task ran for four hours while 199 tasks completed in two minutes.
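Simulating hash partitioning shows why one hot key stalls a whole stage. This sketch assumes 200 shuffle partitions and hypothetical data in which one merchant owns 40% of 100,000 rows. It uses an MD5-based hash where Spark uses Murmur3. That substitution doesn't change the point: every row with the same key lands in the same partition.

```python
# Simulating hash partitioning of a skewed join key in plain Python.
# An MD5-based hash stands in for Spark's Murmur3 (assumption for the sketch).
import hashlib
from collections import Counter
from typing import Iterable, List


def partition_for(key: str, num_partitions: int) -> int:
    digest = hashlib.md5(key.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % num_partitions


def partition_sizes(keys: Iterable[str], num_partitions: int) -> List[int]:
    counts = Counter(partition_for(k, num_partitions) for k in keys)
    return [counts.get(p, 0) for p in range(num_partitions)]


# Hypothetical data: 100,000 transactions; one merchant owns 40% of them,
# and the remaining 60,000 are spread evenly across 2,000 other merchants.
hot_keys = ["MER-0001"] * 40_000
cold_keys = [f"MER-{i:04d}" for i in range(2, 2002) for _ in range(30)]
keys = hot_keys + cold_keys

sizes = partition_sizes(keys, num_partitions=200)
mean = sum(sizes) / len(sizes)
busiest = max(sizes)
print(f"mean rows per partition: {mean:,.0f}")
print(f"busiest partition: {busiest:,} rows "
      f"({busiest / mean:.0f}x the mean, {busiest / len(keys):.0%} of all rows)")
```

The task that owns the busiest partition does at least 80 times the average work, no matter how many executors the cluster has.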
Spark 3.x Adaptive Query Execution (AQE) helps here, but it's not magic. You still need to understand broadcast joins, salting strategies, and when to break a complex join into multiple simpler stages.
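For reference, the Spark configuration docs describe AQE's skew test for joins as a two-part condition. A partition counts as skewed only if it is larger than `spark.sql.adaptive.skewJoin.skewedPartitionFactor` (default 5) times the median partition size, and also larger than `spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes` (default 256MB). AQE then splits the flagged partitions into smaller pieces. The helper below is a sketch of that condition only, not Spark's implementation.

```python
# Sketch of AQE's skew condition for joins: size > factor x median AND size > threshold.
import statistics
from typing import List, Sequence

MiB = 1024 * 1024


def skewed_partitions(sizes_bytes: Sequence[int],
                      factor: float = 5.0,
                      threshold_bytes: int = 256 * MiB) -> List[int]:
    """Return indices of partitions that meet both skew conditions."""
    median = statistics.median(sizes_bytes)
    return [i for i, size in enumerate(sizes_bytes)
            if size > factor * median and size > threshold_bytes]


# One 3000MiB partition among 199 partitions of 100MiB: flagged.
print(skewed_partitions([100 * MiB] * 199 + [3000 * MiB]))

# A 20x relative outlier that is still under 256MiB: not flagged.
print(skewed_partitions([10 * MiB] * 99 + [200 * MiB]))
```

Both conditions must hold, so a uniformly large set of partitions is never treated as skew; that case calls for more shuffle partitions, not skew handling.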
```python
# io.thecodeforge: Python tutorial
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType

spark = (
    SparkSession.builder
    .appName("merchant-risk-join")
    .config("spark.sql.shuffle.partitions", "400")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.skewJoin.enabled", "true")  # AQE detects and splits skewed partitions
    # Flag partitions over 256MB as skewed (they must also exceed
    # spark.sql.adaptive.skewJoin.skewedPartitionFactor x the median partition size)
    .config("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256mb")
    .getOrCreate()
)

# --- SCENARIO: join 2B transactions against 50K merchant records ---
# merchant_df is small (~5MB); transaction_df is massive (~800GB).
merchant_schema = StructType([
    StructField("merchant_id", StringType(), nullable=False),
    StructField("merchant_name", StringType(), nullable=True),
    StructField("merchant_category", StringType(), nullable=True),
    StructField("country_code", StringType(), nullable=True),
])

transaction_schema = StructType([
    StructField("transaction_id", StringType(), nullable=False),
    StructField("account_id", StringType(), nullable=False),
    StructField("merchant_id", StringType(), nullable=False),
    StructField("amount_usd", DoubleType(), nullable=False),
])

merchant_df = spark.read.schema(merchant_schema).parquet("/data/merchants/")
transaction_df = spark.read.schema(transaction_schema).parquet("/data/transactions/")

# ---------------------------------------------------------------
# APPROACH 1: Broadcast join, the right move when one side is small.
# Spark ships the entire small DataFrame to every executor, so the large side
# is never shuffled. This is the fastest possible join. Use it when the small
# side fits comfortably in driver and executor memory (typically under 100MB).
# ---------------------------------------------------------------
merchant_enriched_df = transaction_df.join(
    F.broadcast(merchant_df),  # Explicit hint: don't rely on autoBroadcastJoinThreshold alone
    on="merchant_id",
    how="left",  # Keep all transactions even if merchant metadata is missing
)

# ---------------------------------------------------------------
# APPROACH 2: Salting, the manual fix for a skewed shuffle join when
# neither side is small enough to broadcast (merchant_df stands in for
# that side here to keep the example short).
#
# Skew example: merchant MER-0001 has 40% of all transactions.
# Without salting, one task handles that 40%. Salting splits the hot key
# across SALT_BUCKETS distinct join keys.
# ---------------------------------------------------------------
SALT_BUCKETS = 20  # Spread each key across 20 shuffle partitions instead of 1

# Add a random salt to the large (skewed) side
transaction_salted_df = (
    transaction_df
    .withColumn("salt", (F.rand() * SALT_BUCKETS).cast(IntegerType()))  # Random int 0 to 19 per row
    .withColumn(
        "salted_merchant_id",
        F.concat(F.col("merchant_id"), F.lit("_"), F.col("salt").cast(StringType())),  # e.g. "MER-0001_7"
    )
)

# Replicate the other side once per salt value so every salted key has a match.
# This multiplies its row count by SALT_BUCKETS: acceptable only while that side stays modest.
merchant_exploded_df = (
    merchant_df
    .withColumn("salt", F.explode(F.array([F.lit(i) for i in range(SALT_BUCKETS)])))  # One row per bucket 0..19
    .withColumn(
        "salted_merchant_id",
        F.concat(F.col("merchant_id"), F.lit("_"), F.col("salt").cast(StringType())),
    )
    .drop("merchant_id", "salt")  # Prevent duplicate merchant_id/salt columns after the join
)

# Join on the salted key: the hot key is now spread across 20 partitions
skew_fixed_df = (
    transaction_salted_df
    .join(merchant_exploded_df, on="salted_merchant_id", how="left")
    .drop("salt", "salted_merchant_id")  # Clean up the salt columns post-join
)

# ---------------------------------------------------------------
# Verify the join result: nulls in merchant_name indicate join misses.
# One aggregation computes both numbers; two separate count() calls
# would execute the whole 2B-row join twice.
# ---------------------------------------------------------------
join_stats = skew_fixed_df.agg(
    F.count(F.lit(1)).alias("total"),
    F.sum(F.when(F.col("merchant_name").isNull(), 1).otherwise(0)).alias("misses"),
).first()
total_count, join_miss_count = join_stats["total"], join_stats["misses"]

print(f"Total transactions enriched: {total_count:,}")
print(f"Transactions with missing merchant data: {join_miss_count:,}")
print(f"Enrichment rate: {(total_count - join_miss_count) / total_count * 100:.3f}%")

# Check the partition distribution after the join.
# High stddev relative to the mean = skew still present.
(
    skew_fixed_df
    .groupBy(F.spark_partition_id().alias("partition_id"))
    .count()
    .describe()
    .show()
)
```
```
Transactions with missing merchant data: 12,441
Enrichment rate: 99.999%
+-------+--------------------+--------------------+
|summary|partition_id        |count               |
+-------+--------------------+--------------------+
|count  |400                 |400                 |
|mean   |199.5               |5,119,578.73        |
|stddev |115.37              |48,221.14           |
|min    |0                   |4,891,204           |
|max    |399                 |5,388,901           |
+-------+--------------------+--------------------+
```

A stddev/mean ratio of roughly 0.9% indicates a healthy partition distribution after salting. Before salting, one partition contained 818M rows (40% of the dataset).
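The salting pattern can also be checked without a cluster. This plain-Python simulation assumes the same shape as the scenario above (one merchant with 40% of the rows, 20 salt buckets, 200 partitions) on hypothetical data scaled down to 100,000 rows. The large side gets a random salt, the small side is replicated once per salt value, and the join key is the concatenated salted id, mirroring the concat-based key in the PySpark code. MD5 again stands in for Spark's Murmur3 hash.

```python
# Salting simulated in plain Python. Names and data are hypothetical.
import hashlib
import random
from collections import Counter
from typing import Dict, Iterable, List, Optional, Tuple

SALT_BUCKETS = 20
NUM_PARTITIONS = 200


def partition_for(key: str, num_partitions: int = NUM_PARTITIONS) -> int:
    digest = hashlib.md5(key.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % num_partitions


def busiest_partition(keys: Iterable[str]) -> int:
    return max(Counter(partition_for(k) for k in keys).values())


def salted(merchant_id: str, salt: int) -> str:
    return f"{merchant_id}_{salt}"  # e.g. "MER-0001_7"


def salt_large_side(rows: List[Tuple[str, str]], rng: random.Random) -> List[Tuple[str, str, int]]:
    """(transaction_id, merchant_id) -> (transaction_id, merchant_id, salt)."""
    return [(txn, mid, rng.randrange(SALT_BUCKETS)) for txn, mid in rows]


def explode_small_side(merchants: Dict[str, str]) -> Dict[str, str]:
    """One copy of every merchant per salt value, keyed by the salted id."""
    return {salted(mid, s): name for mid, name in merchants.items() for s in range(SALT_BUCKETS)}


def salted_left_join(rows: List[Tuple[str, str, int]],
                     exploded: Dict[str, str]) -> List[Tuple[str, str, Optional[str]]]:
    return [(txn, mid, exploded.get(salted(mid, salt))) for txn, mid, salt in rows]


merchant_ids = ["MER-0001"] + [f"MER-{i:04d}" for i in range(2, 2002)]
merchants = {mid: f"Merchant {mid}" for mid in merchant_ids}
transactions = [(f"TXN-{n}", "MER-0001") for n in range(40_000)]
transactions += [(f"TXN-C{n}", merchant_ids[1 + n % 2000]) for n in range(60_000)]

salted_rows = salt_large_side(transactions, random.Random(42))
joined = salted_left_join(salted_rows, explode_small_side(merchants))

before = busiest_partition(mid for _, mid in transactions)
after = busiest_partition(salted(mid, salt) for _, mid, salt in salted_rows)
print(f"busiest partition before salting: {before:,} rows; after: {after:,} rows")
```

The checks worth carrying into a real pipeline are the same ones the simulation makes: every transaction still finds exactly one merchant row, and the busiest partition shrinks sharply once the hot key is split.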
Aggregations Without UDFs: Why Native Spark Functions Outperform Python UDFs by Up to 10x
Here's what I see constantly from developers coming from pandas: they hit a transformation that's slightly complex, they can't immediately find the built-in Spark function, and they write a Python UDF. It feels natural. It works in testing. Then it hits production at scale and the job takes four times longer than it should.
The reason is the JVM boundary. Spark's execution engine runs on the JVM. Your Python UDFs run in a separate Python process on each executor. For every batch of rows, Spark has to serialize data from JVM memory, ship it across a local socket to the Python process, execute your Python code, serialize the results back, and deserialize them into JVM memory. This round-trip happens millions of times. I've measured a trivially simple string transformation running 8x slower as a Python UDF than as a native Spark SQL function call.
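The cost shape of per-row versus batched hand-offs can be illustrated with Python's pickle module. This is an analogy, not PySpark's wire protocol: classic UDFs pickle chunks of rows but still call your function once per row, while pandas UDFs move Arrow columnar batches. The sketch counts serialization round trips and bytes for 10,000 hypothetical rows sent one at a time versus in batches of 1,000.

```python
# Per-row vs. batched hand-off, illustrated with pickle (an analogy only).
import pickle
from typing import List, Tuple

Row = Tuple[str, float]


def per_row_round_trips(rows: List[Row]) -> Tuple[List[Row], int, int]:
    """Serialize and deserialize every row on its own."""
    out: List[Row] = []
    calls = total_bytes = 0
    for row in rows:
        payload = pickle.dumps(row)
        calls += 1
        total_bytes += len(payload)
        out.append(pickle.loads(payload))
    return out, calls, total_bytes


def batched_round_trips(rows: List[Row], batch_size: int = 1_000) -> Tuple[List[Row], int, int]:
    """Serialize and deserialize rows in fixed-size batches."""
    out: List[Row] = []
    calls = total_bytes = 0
    for start in range(0, len(rows), batch_size):
        payload = pickle.dumps(rows[start:start + batch_size])
        calls += 1
        total_bytes += len(payload)
        out.extend(pickle.loads(payload))
    return out, calls, total_bytes


rows = [(f"TXN-{i}", i * 1.5) for i in range(10_000)]
_, row_calls, row_bytes = per_row_round_trips(rows)
_, batch_calls, batch_bytes = batched_round_trips(rows)
print(f"per-row: {row_calls:,} round trips, {row_bytes:,} bytes")
print(f"batched: {batch_calls:,} round trips, {batch_bytes:,} bytes")
```

Both paths return the same rows; the batched one pays the fixed per-message overhead 10 times instead of 10,000 times.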
The fix is to learn pyspark.sql.functions deeply. It covers 95% of what you'd ever want to do with a UDF. Window functions handle running totals, rankings, and lag/lead calculations. Higher-order functions (transform, filter, aggregate for arrays; transform_values and map_filter for maps) handle array and map columns. When you genuinely need custom logic that has no Spark equivalent, use Pandas UDFs (also called vectorized UDFs). They move rows in batches as pandas Series or DataFrames using Apache Arrow, which amortizes the per-row serialization cost and typically runs within 2x of native Spark performance.
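Before relying on the window specs in the feature-engineering code that follows, it helps to spell out what they compute. The plain-Python helper below does that for one account: a running total over a ROWS frame, a 30-day RANGE frame on epoch seconds, and lag(1). It is hypothetical, written as an O(n²) loop for checking intuition about frame boundaries rather than anything like Spark's implementation, and the sample data is made up.

```python
# What the window specs compute, spelled out in plain Python.
from collections import defaultdict
from datetime import datetime, timezone
from typing import Dict, List

THIRTY_DAYS_S = 30 * 24 * 3600


def window_features(txns: List[Dict]) -> List[Dict]:
    by_account: Dict[str, List[Dict]] = defaultdict(list)
    for t in txns:
        by_account[t["account_id"]].append(t)  # partitionBy("account_id")

    out: List[Dict] = []
    for account_txns in by_account.values():
        ordered = sorted(account_txns, key=lambda t: t["ts"])  # orderBy("transaction_ts")
        running_total = 0.0
        prev_amount = None
        for t in ordered:
            running_total += t["amount"]  # ROWS frame: every row up to and including this one
            now = t["ts"].timestamp()
            # RANGE frame: rows whose timestamp lies in [now - 30 days, now], peers included
            in_range = [u for u in ordered if now - THIRTY_DAYS_S <= u["ts"].timestamp() <= now]
            out.append({
                **t,
                "lifetime_spend": running_total,
                "rolling_30d_spend": sum(u["amount"] for u in in_range),
                "rolling_30d_txn_count": len(in_range),
                "prev_transaction_amount": prev_amount,  # lag(1): None for the first row
            })
            prev_amount = t["amount"]
    return out


def utc(*args: int) -> datetime:
    return datetime(*args, tzinfo=timezone.utc)


sample = [
    {"account_id": "ACC-1", "ts": utc(2024, 1, 1), "amount": 45.0},
    {"account_id": "ACC-1", "ts": utc(2024, 1, 10), "amount": 320.0},
    {"account_id": "ACC-1", "ts": utc(2024, 2, 20), "amount": 5200.0},
]
for row in window_features(sample):
    print(row["ts"].date(), row["lifetime_spend"], row["rolling_30d_spend"],
          row["rolling_30d_txn_count"], row["prev_transaction_amount"])
```

The third row shows the difference between the two frames: lifetime spend keeps accumulating, while the 30-day range window has already dropped both January transactions.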
```python
# io.thecodeforge: Python tutorial
import pandas as pd
from pyspark.sql import SparkSession, Window, functions as F
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, TimestampType

spark = (
    SparkSession.builder
    .appName("account-behaviour-features")
    .config("spark.sql.shuffle.partitions", "200")
    .config("spark.sql.adaptive.enabled", "true")
    .getOrCreate()
)

transaction_schema = StructType([
    StructField("transaction_id", StringType(), nullable=False),
    StructField("account_id", StringType(), nullable=False),
    StructField("merchant_id", StringType(), nullable=False),
    StructField("amount_usd", DoubleType(), nullable=False),
    StructField("transaction_ts", TimestampType(), nullable=False),
])

transaction_df = spark.read.schema(transaction_schema).parquet("/data/transactions/")

# ---------------------------------------------------------------
# NATIVE AGGREGATIONS: no UDFs needed for standard feature engineering
# ---------------------------------------------------------------

# Window spec: per account, ordered by time, unbounded lookback.
# This is the pattern for running totals and sequence positions.
account_time_window = (
    Window
    .partitionBy("account_id")  # Each account gets its own independent window
    .orderBy("transaction_ts")  # Chronological order within the partition
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)  # All rows up to the current one
)

# 30-day rolling window for detecting velocity anomalies.
# rangeBetween works on the ORDER BY value, so order by epoch seconds.
account_30day_window = (
    Window
    .partitionBy("account_id")
    .orderBy(F.col("transaction_ts").cast("long"))  # Timestamp cast to epoch seconds
    .rangeBetween(-30 * 24 * 3600, 0)  # 30 days in seconds, up to and including the current row
)

feature_df = (
    transaction_df
    .withColumn("account_lifetime_spend", F.sum("amount_usd").over(account_time_window))  # Running total spend
    .withColumn("account_txn_rank", F.rank().over(account_time_window))  # Position in time order; ties share a rank
    .withColumn("rolling_30d_spend", F.sum("amount_usd").over(account_30day_window))  # 30-day rolling spend
    .withColumn("rolling_30d_txn_count", F.count("transaction_id").over(account_30day_window))  # 30-day velocity
    .withColumn(
        "prev_transaction_amount",
        F.lag("amount_usd", 1).over(Window.partitionBy("account_id").orderBy("transaction_ts")),  # For deltas
    )
    .withColumn(
        "amount_delta_vs_prev",
        F.col("amount_usd") - F.coalesce(F.col("prev_transaction_amount"), F.lit(0.0)),  # First txn has no prev
    )
    .withColumn(
        "is_large_transaction",
        # >50% of the 30-day spend (which includes this transaction) in one transaction
        F.when(F.col("amount_usd") > F.col("rolling_30d_spend") * 0.5, F.lit(True))
        .otherwise(F.lit(False)),
    )
)

# ---------------------------------------------------------------
# PANDAS UDF: use ONLY when native functions genuinely can't do it,
# e.g. calling a custom ML scorer or a Python-only library.
# Vectorized: receives pd.Series batches, returns a pd.Series. Arrow handles serialization.
# Avoid row-level Python UDFs (the @udf decorator); they kill throughput.
# ---------------------------------------------------------------
@pandas_udf(DoubleType())  # Vectorized UDF: Arrow-serialized batches, not row-by-row
def compute_velocity_score(rolling_count: pd.Series, rolling_spend: pd.Series) -> pd.Series:
    """
    Velocity score combining transaction count and spend. It stands in for genuinely
    custom logic; this particular arithmetic could also be written with native functions.

    Caveat: each call receives ONE Arrow batch (up to
    spark.sql.execution.arrow.maxRecordsPerBatch rows), so .max() below is the batch
    maximum, not the dataset maximum. For a dataset-wide normalization, compute the
    maxima with a native aggregation first and pass them in.
    """
    count_score = (rolling_count / rolling_count.max()).fillna(0)  # Relative to the batch's max count
    spend_score = (rolling_spend / rolling_spend.max()).fillna(0)  # Relative to the batch's max spend
    return (count_score * 0.4 + spend_score * 0.6).round(4)  # Weighted composite score


scored_df = feature_df.withColumn(
    "velocity_risk_score",
    compute_velocity_score(
        F.col("rolling_30d_txn_count").cast(DoubleType()),
        F.col("rolling_30d_spend"),
    ),
)

# Don't partitionBy("account_id"): a high-cardinality key creates one directory per
# account, which is exactly the small-files problem. Instead, cluster each account's
# rows into the same task and sort them, so files stay large and Parquet min/max
# statistics let readers skip row groups when filtering on account_id.
(
    scored_df
    .repartition(200, "account_id")  # All rows for an account land in the same task
    .sortWithinPartitions("account_id", "transaction_ts")
    .write
    .mode("overwrite")
    .parquet("/data/features/account_behaviour/")
)

print("Feature generation complete. Verifying output...")
output_df = spark.read.parquet("/data/features/account_behaviour/")
(
    output_df
    .orderBy("account_id", "transaction_ts")
    .select(
        "transaction_id", "account_id", "amount_usd",
        "rolling_30d_spend", "rolling_30d_txn_count",
        "is_large_transaction", "velocity_risk_score",
    )
    .show(8, truncate=False)
)
```
```
+----------------+----------+----------+-----------------+---------------------+--------------------+-------------------+
|transaction_id  |account_id|amount_usd|rolling_30d_spend|rolling_30d_txn_count|is_large_transaction|velocity_risk_score|
+----------------+----------+----------+-----------------+---------------------+--------------------+-------------------+
|TXN-1001-A      |ACC-00001 |45.00     |45.00            |1                    |false               |0.0012             |
|TXN-1002-A      |ACC-00001 |320.00    |365.00           |2                    |false               |0.0089             |
|TXN-1003-A      |ACC-00001 |5200.00   |5565.00          |3                    |true                |0.1423             |
|TXN-2001-B      |ACC-00002 |1200.00   |1200.00          |1                    |false               |0.0341             |
|TXN-2002-B      |ACC-00002 |980.00    |2180.00          |2                    |false               |0.0612             |
|TXN-3001-C      |ACC-00003 |89500.00  |89500.00         |1                    |false               |0.9871             |
|TXN-3002-C      |ACC-00003 |125000.00 |214500.00        |2                    |true                |0.9998             |
|TXN-4001-D      |ACC-00004 |22.50     |22.50            |1                    |false               |0.0003             |
+----------------+----------+----------+-----------------+---------------------+--------------------+-------------------+
```
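One caveat about the pandas UDF above: a scalar pandas UDF is called once per Arrow batch (at most `spark.sql.execution.arrow.maxRecordsPerBatch` rows, 10,000 by default). That means `rolling_count.max()` inside `compute_velocity_score` is the maximum of that batch, not of the dataset. The plain-Python simulation below assumes a batch size of 3 to make the effect visible. The helper names are hypothetical.

```python
# Why `s / s.max()` inside a scalar pandas UDF is batch-relative, simulated in plain Python.
from typing import List, Sequence


def normalize_per_batch(values: Sequence[float], batch_size: int) -> List[float]:
    """What the UDF effectively computes: each batch divided by its own maximum."""
    out: List[float] = []
    for start in range(0, len(values), batch_size):
        batch = values[start:start + batch_size]
        batch_max = max(batch)
        out.extend(v / batch_max if batch_max else 0.0 for v in batch)
    return out


def normalize_global(values: Sequence[float]) -> List[float]:
    """What a dataset-wide normalization would compute."""
    global_max = max(values)
    return [v / global_max if global_max else 0.0 for v in values]


counts = [1.0, 2.0, 3.0, 10.0, 20.0, 40.0]
print(normalize_per_batch(counts, batch_size=3))
print(normalize_global(counts))
```

The count of 3 scores 1.0 in its own batch but 0.075 against the whole dataset, so the same value can get a different score depending on which batch it lands in. If the score needs a dataset-wide scale, compute the maxima first with a native aggregation such as F.max and pass them into the calculation.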
Writing to Production Storage: Partition Strategy, Output Modes, and Avoiding the Small Files Problem
Writing Spark output correctly is just as important as reading and processing correctly, and it's where I see the most rookie mistakes land in production. The most insidious one: after all your careful processing, you write out with the default partition count (or worse, you call repartition(1) because you want a single output file), and you've just created either thousands of tiny 1KB files or one massive unparallelizable blob.
The small files problem is real and painful. HDFS and object stores like S3 weren't designed for millions of tiny files: every file carries metadata overhead (NameNode memory on HDFS), and S3 LIST operations are slow and billed per request. Downstream Spark jobs reading your output have to list and open every file. If your write produced 10,000 partition directories and the rows for each directory were spread across 3 write tasks, the downstream reader opens 30,000 files before it processes a single row. I've seen a single poorly partitioned write turn a downstream job's startup time from 8 seconds to 12 minutes.
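The file arithmetic in the paragraph above can be modeled directly. This sketch assumes the common simplification that, with `partitionBy`, each write task emits one file per distinct output directory it holds rows for; settings such as `spark.sql.files.maxRecordsPerFile` can split files further. The helper name and data are hypothetical.

```python
# Counting the files a partitioned write produces, in plain Python.
from typing import Hashable, Sequence


def files_written(task_rows: Sequence[Sequence[Hashable]]) -> int:
    """task_rows[i] lists the output-directory key of every row held by write task i."""
    return sum(len(set(keys)) for keys in task_rows)


DIRECTORIES = 10_000

# Every directory's rows are spread over 3 tasks (e.g. after an unrelated shuffle).
spread = [list(range(DIRECTORIES)) for _ in range(3)]

# Rows hash-partitioned by the directory key first: each directory lives in one task.
clustered = [list(range(task, DIRECTORIES, 200)) for task in range(200)]

print("spread across tasks:", files_written(spread))
print("clustered by output key:", files_written(clustered))
```

Clustering rows by the output key before the write brings the count down to one file per directory. Going lower than that requires fewer directories, which means a coarser partition key.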
The right approach is intentional: partition your output by the dimensions your downstream queries actually filter on, target 100-500MB per output file, and use coalesce (not repartition) when you need to reduce partition count without a shuffle. For streaming or incremental pipelines, understand the difference between overwrite, append, and the Delta Lake / Iceberg merge patterns, because overwrite on partitioned data can silently delete partitions you didn't intend to touch.
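The overwrite risk is easiest to see with a dict-based model of the two values of `spark.sql.sources.partitionOverwriteMode`, assuming each key stands for one partition directory such as `report_date=2024-01-15`. With the default `static` mode, an overwrite replaces everything under the output path. With `dynamic`, only the partitions present in the incoming data are replaced. The `overwrite` helper is hypothetical and models only those semantics.

```python
# Static vs. dynamic partition overwrite, modeled with dicts (plain Python).
from typing import Dict, List

Table = Dict[str, List[dict]]  # partition value -> rows in that partition


def overwrite(table: Table, batch: Table, mode: str = "static") -> Table:
    if mode == "static":
        return {k: list(v) for k, v in batch.items()}  # Whole output path replaced
    if mode == "dynamic":
        merged = {k: list(v) for k, v in table.items()}
        merged.update({k: list(v) for k, v in batch.items()})  # Only partitions in the batch replaced
        return merged
    raise ValueError(f"unknown partitionOverwriteMode: {mode!r}")


existing = {
    "2024-01-14": [{"txn": "A"}],
    "2024-01-15": [{"txn": "B"}],
}
rerun = {"2024-01-15": [{"txn": "B2"}]}  # Re-processing a single day

print(sorted(overwrite(existing, rerun, "static")))   # January 14 is gone
print(sorted(overwrite(existing, rerun, "dynamic")))  # January 14 survives
```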
```python
# io.thecodeforge: Python tutorial
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, TimestampType

spark = (
    SparkSession.builder
    .appName("daily-risk-report-writer")
    .config("spark.sql.shuffle.partitions", "400")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    # Spark 3.2+: minimum size of post-shuffle partitions after AQE coalescing
    .config("spark.sql.adaptive.coalescePartitions.minPartitionSize", "128mb")
    .getOrCreate()
)

transaction_schema = StructType([
    StructField("transaction_id", StringType(), nullable=False),
    StructField("account_id", StringType(), nullable=False),
    StructField("merchant_id", StringType(), nullable=False),
    StructField("amount_usd", DoubleType(), nullable=False),
    StructField("transaction_ts", TimestampType(), nullable=False),
    StructField("risk_tier", StringType(), nullable=True),
    StructField("country_code", StringType(), nullable=True),
])

processed_df = spark.read.schema(transaction_schema).parquet("/data/enriched_transactions/")

# Add partition columns before writing: partitionBy turns them into directory names.
dated_df = (
    processed_df
    .withColumn("report_date", F.to_date(F.col("transaction_ts")))  # Date part for daily partitioning
    .withColumn("report_hour", F.hour(F.col("transaction_ts")))  # Regular column for intraday filtering
)

# ---------------------------------------------------------------
# APPROACH 1: Large dataset, repartition by the output keys to control files.
# repartition() DOES cause a full shuffle. Only use it when partition count
# and distribution matter more than the shuffle cost.
# ---------------------------------------------------------------

# Check the current partition count before writing
current_partitions = dated_df.rdd.getNumPartitions()
print(f"Current partition count before write: {current_partitions}")

# Repartition by the SAME columns used in partitionBy below. Hash partitioning sends
# every row of a given (report_date, risk_tier) pair to the same task, so each output
# directory receives a single file (unless spark.sql.files.maxRecordsPerFile splits it).
# 200 is the number of write tasks, not files per directory. Trade-off: one very large
# (date, tier) pair becomes one large file written by one task.
dated_df_repartitioned = dated_df.repartition(200, "report_date", "risk_tier")

# Dynamic partition overwrite: only the partitions present in this DataFrame are replaced.
# Without it (the default is "static"), mode("overwrite") wipes the ENTIRE output path.
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

(
    dated_df_repartitioned
    .write
    .mode("overwrite")  # Dynamic overwrite: only affects partitions in this batch
    .partitionBy("report_date", "risk_tier")  # Creates /report_date=2024-01-15/risk_tier=HIGH/ directories
    .option("compression", "snappy")  # Snappy: good balance of compression ratio vs CPU cost
    .parquet("/data/risk_reports/daily/")
)

# ---------------------------------------------------------------
# APPROACH 2: Small-to-medium result set, coalesce to avoid small files.
# coalesce() reduces partition count WITHOUT a full shuffle (narrow transformation).
# ---------------------------------------------------------------
daily_summary_df = (
    processed_df
    .groupBy("country_code", "risk_tier", F.to_date("transaction_ts").alias("report_date"))
    .agg(
        F.count("transaction_id").alias("transaction_count"),
        F.sum("amount_usd").alias("total_amount_usd"),
        F.avg("amount_usd").alias("avg_amount_usd"),
        F.countDistinct("account_id").alias("unique_accounts"),
    )
)

# After the groupBy there are up to 400 shuffle partitions (AQE may already have merged
# some), but this summary is small (~100K rows). coalesce() merges them into fewer, larger
# files without another shuffle. It also lowers the parallelism of the final aggregation
# step, which is fine for a result this small.
(
    daily_summary_df
    .coalesce(10)  # At most 10 files for a small summary dataset
    .write
    .mode("overwrite")
    .option("compression", "snappy")
    .parquet("/data/risk_reports/daily_summary/")
)

# ---------------------------------------------------------------
# Validate the output. In production, wire this into your data quality framework.
# ---------------------------------------------------------------
output_df = spark.read.parquet("/data/risk_reports/daily_summary/")
output_df.printSchema()
output_df.orderBy("report_date", "country_code", "risk_tier").show(10)
# getNumPartitions() on a read reflects Spark's split planning, which may pack several
# small files into one partition; output_df.inputFiles() lists the files themselves.
print(f"Output partition count: {output_df.rdd.getNumPartitions()}")
print(f"Total summary records: {output_df.count():,}")
```
```
root
 |-- country_code: string (nullable = true)
 |-- risk_tier: string (nullable = true)
 |-- report_date: date (nullable = true)
 |-- transaction_count: long (nullable = true)
 |-- total_amount_usd: double (nullable = true)
 |-- avg_amount_usd: double (nullable = true)
 |-- unique_accounts: long (nullable = true)

+------------+---------+-----------+-----------------+------------------+-----------------+---------------+
|country_code|risk_tier|report_date|transaction_count|total_amount_usd  |avg_amount_usd   |unique_accounts|
+------------+---------+-----------+-----------------+------------------+-----------------+---------------+
|GB          |CRITICAL |2024-01-15 |8,421            |42,891,330.00     |5,093.71         |1,204          |
|GB          |HIGH     |2024-01-15 |31,882           |19,445,200.50     |610.02           |8,891          |
|GB          |MEDIUM   |2024-01-15 |284,991          |71,200,450.25     |249.83           |142,003        |
|US          |CRITICAL |2024-01-15 |22,014           |198,320,880.00    |9,008.77         |3,421          |
|US          |HIGH     |2024-01-15 |89,441           |55,124,990.75     |616.38           |24,882         |
|US          |MEDIUM   |2024-01-15 |891,203          |224,531,200.00    |251.93           |401,221        |
|DE          |CRITICAL |2024-01-15 |4,110            |18,440,210.00     |4,486.67         |891            |
|DE          |HIGH     |2024-01-15 |14,221           |8,920,440.50      |627.29           |4,201          |
|DE          |MEDIUM   |2024-01-15 |102,441          |25,841,003.75     |252.26           |52,001         |
|SG          |CRITICAL |2024-01-15 |1,204            |10,420,800.00     |8,656.81         |312            |
+------------+---------+-----------+-----------------+------------------+-----------------+---------------+
Output partition count: 10
Total summary records: 2,847
```
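The `coalesce(10)` call above avoids a shuffle by merging whole partitions, and that is also its limit: it cannot rebalance data that is already uneven. The sketch below models both operations on partition sizes alone. It assumes contiguous grouping for coalesce (Spark's real coalescer also weighs data locality) and round-robin distribution for `repartition(n)` without columns. The helper names are hypothetical.

```python
# coalesce() vs repartition(), modeled on partition sizes (plain Python).
from typing import List, Sequence


def coalesce_sizes(partition_sizes: Sequence[int], target: int) -> List[int]:
    """Merge whole input partitions into `target` contiguous groups; no row is re-hashed."""
    n = len(partition_sizes)
    target = min(target, n)  # coalesce can only reduce the partition count
    return [sum(partition_sizes[g * n // target:(g + 1) * n // target]) for g in range(target)]


def repartition_sizes(partition_sizes: Sequence[int], target: int) -> List[int]:
    """Full shuffle with round-robin distribution: every row moves, sizes even out."""
    total = sum(partition_sizes)
    return [total // target + (1 if i < total % target else 0) for i in range(target)]


skewed_input = [100, 100, 100, 5_000]
print("coalesce(2):   ", coalesce_sizes(skewed_input, 2))
print("repartition(2):", repartition_sizes(skewed_input, 2))
print("coalesce(10) of 400 small partitions:", coalesce_sizes([7] * 400, 10))
```

For many small, evenly sized partitions, like the post-aggregation summary, coalesce is the cheap and correct choice. When the input is already lopsided, only a shuffle evens it out.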
| Aspect | Python UDF (udf decorator) | Pandas UDF (pandas_udf) | Native Spark Function (F.*) |
|---|---|---|---|
| Serialization overhead | Per-row Python ↔ JVM round-trip | Batch Arrow serialization | None: runs inside the JVM |
| Typical throughput vs native | 5x to 20x slower | 1.5x to 3x slower | Baseline (fastest) |
| Use case fit | Legacy code only; avoid | Custom ML scoring, complex regex | 95% of production transformations |
| Null handling | Must handle None explicitly in Python | Must handle NaN/None in pandas | Built-in null propagation |
| Pushdown / optimization | No: opaque to the Catalyst optimizer | No: opaque to the Catalyst optimizer | Yes: Catalyst can optimize |
| Debugging experience | Stack traces cross the JVM/Python boundary | Pandas exceptions surface clearly | Clear Spark plan errors |
| When AQE helps | No: the bottleneck is serialization | Partially: batching helps | Yes: full AQE benefits apply |
| Type safety | Runtime type errors only | Runtime type errors only | Schema checked at analysis time, before execution |
Key Takeaways
- Every .count(), .show(), and .write() you call is an action that triggers a full DAG execution. Every transformation before it is just a plan. Confusing the two is the source of most PySpark performance bugs, including the "why did my job run twice" mystery: two actions on an uncached DataFrame compute it twice.
- Data skew on a join key (not hardware, not network, not Spark version) is the single most common reason production PySpark jobs stall at 99% completion. One key holding 40% of your rows means one task handles 40% of the work while the other 199 tasks finish and sit idle. Broadcast joins, salting, and AQE's skew-join splitting are the real fixes.
- Reach for PySpark when your data genuinely doesn't fit in memory on a single machine, or when you need fault-tolerant distributed processing. For datasets under 10GB, pandas on a decent machine is usually faster to develop, faster to run, and dramatically easier to debug. Don't distribute problems that don't need distributing.
- Native Spark SQL functions (pyspark.sql.functions) run inside the JVM and are fully optimizable by Catalyst. Python UDFs run in a separate process with per-row serialization overhead. If you write a Python UDF and there's an equivalent F.* function, you're paying a 5x to 20x tax for nothing, and Catalyst can't see inside your UDF to optimize around it.
Common Mistakes to Avoid
- Mistake 1: Calling .collect() on a large DataFrame to 'check the results'. All of the distributed data is pulled into the driver's heap. On large (10GB+) DataFrames the job then fails in one of two ways:
  - a SparkException, because the serialized results exceed spark.driver.maxResultSize (1g by default);
  - an OutOfMemoryError that kills the driver.

  Use .show(20), .limit(100).toPandas(), or .describe() instead. They return only a small, bounded result and leave the bulk of the data on the executors.
- Mistake 2: Leaving spark.sql.shuffle.partitions at its default of 200 for shuffles over 100GB. Each post-shuffle partition ends up at 500MB or more (100GB / 200), which causes executor OOMs and task retries. Set it to roughly ceil(shuffle_size_in_MB / 150), which targets 100-200MB per partition. Also enable AQE with adaptive coalescing so Spark can merge small partitions automatically.
- Mistake 3: Writing a Python UDF for string transformations like uppercasing, trimming, or regex matching. Throughput drops 8x-15x with no benefit, because common string operations already have direct equivalents in pyspark.sql.functions: F.upper(), F.trim(), F.regexp_extract(), F.regexp_replace(). Check the docs before reaching for udf().
- Mistake 4: Caching a DataFrame that's only used once. It wastes executor memory, can evict DataFrames you actually need cached, and adds the cost of materializing the cache. Cache only when a DataFrame is reused by two or more actions in the same application, and call .unpersist() explicitly when you're done with it.
- Mistake 5: Using repartition() when you only need coalesce(). repartition() always causes a full shuffle across the network, while coalesce() merges existing partitions without a shuffle. If you're reducing the partition count (e.g., from 400 to 10 before writing a small summary), coalesce is almost always cheaper. Use repartition only when you need to increase the partition count or redistribute rows by a specific key. One caveat: a drastic coalesce (say, to 1) also shrinks the parallelism of the upstream work in the same stage. When that upstream computation is heavy, a repartition can finish sooner.
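Mistake 4 and the 'why did my job run twice' problem both come from lazy evaluation. Below is a pure-Python toy model of that behaviour. The `Source` and `LazyFrame` classes are hypothetical, not Spark APIs.

- Transformations only record a plan.
- Each action walks that plan back to the source and re-reads it, unless a frame along the way is cached.
- As with Spark's `cache()`, the toy cache is filled by the first action, not by the `cache()` call itself.

```python
class Source:
    """Hypothetical data source that counts how many times it is read."""

    def __init__(self, rows):
        self._rows = list(rows)
        self.reads = 0

    def read(self):
        self.reads += 1
        return list(self._rows)


class LazyFrame:
    """Toy lazy frame: transformations build a plan, actions execute it."""

    def __init__(self, source=None, parent=None, step=None):
        self.source, self.parent, self.step = source, parent, step
        self._cache_requested = False
        self._cached_rows = None

    # Transformations: return a new node in the plan; nothing is computed yet.
    def map(self, fn):
        return LazyFrame(parent=self, step=("map", fn))

    def filter(self, pred):
        return LazyFrame(parent=self, step=("filter", pred))

    # Like Spark's cache(): only marks the frame; the first action fills the cache.
    def cache(self):
        self._cache_requested = True
        return self

    def unpersist(self):
        self._cache_requested, self._cached_rows = False, None
        return self

    def _compute(self):
        if self._cached_rows is not None:
            return self._cached_rows
        if self.parent is None:
            rows = self.source.read()
        else:
            rows = self.parent._compute()
            kind, fn = self.step
            rows = [fn(r) for r in rows] if kind == "map" else [r for r in rows if fn(r)]
        if self._cache_requested:
            self._cached_rows = rows
        return rows

    # Actions: each call walks the whole plan back to the source (or the nearest cache).
    def count(self):
        return len(self._compute())

    def collect(self):
        return list(self._compute())


if __name__ == "__main__":
    src = Source(range(10))
    squares = LazyFrame(src).filter(lambda x: x % 2 == 0).map(lambda x: x * x)
    squares.count()
    squares.collect()
    print(src.reads)  # 2: two actions, two full executions

    src = Source(range(10))
    evens = LazyFrame(src).filter(lambda x: x % 2 == 0).cache()
    evens.count()
    evens.map(lambda x: x + 1).collect()
    print(src.reads)  # 1: the second action reused the cached rows
```

Two actions on an uncached plan read the source twice. Marking the shared frame with `cache()` cuts that to one read, and `unpersist()` brings the recomputation back. This is why caching only pays off when a frame feeds two or more actions.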
Interview Questions on This Topic
- You have a PySpark job where one join stage consistently sits at 99% completion for two hours while all other tasks finish in five minutes. What's happening, how do you confirm it, and what are your mitigation options given that both sides of the join are too large to broadcast?
- When would you choose a SortMergeJoin over a BroadcastHashJoin in a production Spark pipeline, and what signals in the Spark UI would tell you that your current join strategy is wrong?
- What happens when you use a Python UDF inside a Spark Structured Streaming job on a stateful aggregation, and what's the production failure mode that bites teams who only test this on batch data?
- Explain Spark's lazy evaluation and how it interacts with the Catalyst optimizer. Given a pipeline with five chained transformations ending in a write, which specific optimizations does Catalyst apply, and how would you use explain() to verify they're working?
Frequently Asked Questions
How do I choose the right number for spark.sql.shuffle.partitions in PySpark?
Target 100-200MB of data per partition after a shuffle. Take your post-shuffle data size in MB and divide by 150; that's your starting point. A 300GB shuffle should use roughly 2,000 partitions (300,000MB / 150), not the default 200.

On Spark 3.x, AQE merges undersized partitions automatically at runtime. This needs spark.sql.adaptive.enabled=true (the default since Spark 3.2) plus spark.sql.adaptive.coalescePartitions.enabled=true. That makes erring on the high side safer than erring on the low side. Too few partitions cause OOMs, while too many only cause small-task overhead that AQE can fix.
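Here is a minimal sketch of the sizing rule above as plain Python. It assumes the 150MB-per-partition target from this answer and decimal units (1GB = 1,000MB) for round numbers. The helper names are hypothetical; only the three configuration keys are real Spark settings.

```python
import math

TARGET_MB_PER_PARTITION = 150  # middle of the 100-200MB target range


def recommended_shuffle_partitions(shuffle_size_mb: float,
                                   target_mb: float = TARGET_MB_PER_PARTITION) -> int:
    """Starting value for spark.sql.shuffle.partitions: ceil(size / target)."""
    if shuffle_size_mb <= 0:
        raise ValueError("shuffle_size_mb must be positive")
    return max(1, math.ceil(shuffle_size_mb / target_mb))


def avg_partition_mb(shuffle_size_mb: float, partitions: int) -> float:
    """Average post-shuffle partition size for a given partition count."""
    return shuffle_size_mb / partitions


def shuffle_conf(shuffle_size_mb: float) -> dict:
    """Settings to apply via SparkSession.builder.config(key, value) or spark.conf.set."""
    return {
        "spark.sql.shuffle.partitions": str(recommended_shuffle_partitions(shuffle_size_mb)),
        "spark.sql.adaptive.enabled": "true",
        "spark.sql.adaptive.coalescePartitions.enabled": "true",
    }


if __name__ == "__main__":
    size_mb = 300_000  # a ~300GB shuffle
    print(recommended_shuffle_partitions(size_mb))  # 2000
    print(avg_partition_mb(size_mb, 200))           # 1500.0 MB per partition at the default
    print(shuffle_conf(size_mb))
```

The second print shows why the default hurts at this scale. With 200 partitions, each one averages about 1.5GB, roughly ten times the target. Treat the computed value as a starting point and let AQE coalesce anything that turns out too small.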
What's the difference between repartition() and coalesce() in PySpark?
repartition() performs a full shuffle across the network and can increase or decrease the partition count. coalesce() merges existing partitions without a shuffle and can only decrease the partition count. Use coalesce when you're reducing partitions before a write and don't need even distribution. It's cheaper because no data goes through a shuffle. Use repartition when you need to redistribute data by a specific key or increase the partition count, accepting the shuffle cost as necessary.
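A toy model makes the trade-off concrete. This pure-Python sketch makes two simplifying assumptions:

- `coalesce` merges adjacent partitions. Spark's real coalescer also weighs data locality.
- `repartition` deals rows out round-robin, which is roughly what `DataFrame.repartition(n)` does when no columns are given.

Both helpers are hypothetical. Each returns the new partitions plus the number of rows that went through a shuffle.

```python
def coalesce(partitions, n):
    """Merge existing partitions into n groups of adjacent partitions; no row is re-hashed."""
    if n >= len(partitions):
        # coalesce cannot increase the partition count
        return [list(p) for p in partitions], 0
    groups = [[] for _ in range(n)]
    for i, part in enumerate(partitions):
        groups[i * n // len(partitions)].extend(part)
    return groups, 0  # zero rows shuffled


def repartition(partitions, n):
    """Deal every row round-robin into n new partitions: a full shuffle of all rows."""
    out = [[] for _ in range(n)]
    rows = [row for part in partitions for row in part]
    for i, row in enumerate(rows):
        out[i % n].append(row)
    return out, len(rows)  # every row was shuffled


def sizes(partitions):
    return [len(p) for p in partitions]


if __name__ == "__main__":
    # Lopsided input: one big partition, three empty ones, four small ones.
    data = [list(range(100)), [], [], [], list(range(5)), list(range(5)), list(range(5)), list(range(5))]
    merged, moved = coalesce(data, 2)
    print(sizes(merged), moved)    # [100, 20] 0
    balanced, moved = repartition(data, 2)
    print(sizes(balanced), moved)  # [60, 60] 120
```

With lopsided input, coalesce shuffles nothing but keeps the imbalance. Repartition evens the partitions out at the cost of shuffling all 120 rows. In PySpark the corresponding calls are `df.coalesce(n)` and either `df.repartition(n)` or `df.repartition(n, "key")`.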
How do I debug a PySpark job that's stuck at 99% completion?
Open the Spark UI (default port 4040 on the driver) and go to the Stages tab. Find the stage stuck at 99% and compare the straggler's duration in the task table with the median in the stage's Summary Metrics. If one task is running 100x longer than the median, you have data skew.

To confirm it, check that task's 'Shuffle Read Size / Records' column ('Input Size / Records' for a stage that reads from storage). One partition should hold massively more data than the others. The fixes are:
- a broadcast join, if the smaller side fits in memory;
- salting the join key, if both sides are large;
- on Spark 3.x, AQE's skew-join optimization, which can also split the oversized partition for sort-merge joins.

Also check 'GC Time' in the task metrics. If GC is consuming over 20% of task time, you have a memory pressure problem that needs its own fix, whether or not skew is also present.
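These checks can be scripted once you've copied task metrics out of the Spark UI (or its REST API). This is a minimal sketch under stated assumptions:

- `TaskMetrics` and `diagnose_stage` are hypothetical names.
- The 100x duration ratio and the 20% GC share come from the answer above.
- The 10x shuffle-read ratio used to confirm skew is an assumed threshold that you should tune.

```python
from dataclasses import dataclass
from statistics import median


@dataclass
class TaskMetrics:
    """Per-task numbers as shown in the Spark UI task table (hypothetical container)."""
    duration_s: float
    gc_time_s: float
    shuffle_read_mb: float


def diagnose_stage(tasks, duration_ratio_limit=100.0, read_ratio_limit=10.0, gc_share_limit=0.20):
    """Apply the skew and GC heuristics to the slowest task of one stage.

    Assumes non-zero median duration and median shuffle read.
    """
    median_duration = median([t.duration_s for t in tasks])
    median_read = median([t.shuffle_read_mb for t in tasks])
    straggler = max(tasks, key=lambda t: t.duration_s)

    duration_ratio = straggler.duration_s / median_duration
    read_ratio = straggler.shuffle_read_mb / median_read
    gc_share = straggler.gc_time_s / straggler.duration_s

    return {
        "duration_ratio": duration_ratio,
        "read_ratio": read_ratio,
        "gc_share": gc_share,
        # Much slower AND much more data than its peers: classic key skew.
        "likely_skew": duration_ratio >= duration_ratio_limit and read_ratio >= read_ratio_limit,
        # Heavy GC on the straggler: memory pressure, skewed or not.
        "memory_pressure": gc_share > gc_share_limit,
    }


if __name__ == "__main__":
    # 199 normal tasks plus one straggler that read 400x the median shuffle data.
    stage = [TaskMetrics(30.0, 0.5, 150.0) for _ in range(199)]
    stage.append(TaskMetrics(7200.0, 300.0, 60_000.0))
    print(diagnose_stage(stage))  # likely_skew=True, memory_pressure=False
```

Requiring both a duration outlier and a data-size outlier is a deliberate choice. A task that is slow without reading more data points to something other than key skew, such as GC pressure or a bad node.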
Why does my PySpark job perform fine on 10 million rows in staging but crash on 2 billion rows in production?
Three failure modes cause this almost every time:

1. A join or groupBy produces data skew that is invisible at small scale. One key representing 0.001% of staging data might represent 40% of production data.
2. spark.sql.shuffle.partitions=200 works fine for small data but creates 10GB partitions at full scale that OOM individual executors.
3. A Python UDF that's slow but tolerable at small scale becomes a throughput bottleneck at full scale. UDF overhead is per-row and 2B rows is 200x more data than 10M. What takes 2 minutes on 10M rows therefore takes about 400 minutes on 2B rows on the same cluster.

Always test with at least 10% of production data volume, not 0.5% (10M of 2B rows). Always check the Spark UI partition size distribution after every shuffle stage.
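The arithmetic in this answer is worth running before every scale-up. Here is a minimal sketch. It assumes per-row UDF cost and identical parallelism in staging and production, so runtime scales linearly with row count. It also assumes a hypothetical 2TB production shuffle to reproduce the 10GB partitions mentioned above. The function names are illustrative.

```python
def sample_fraction(staging_rows: int, prod_rows: int) -> float:
    """Share of production volume that staging actually exercises."""
    return staging_rows / prod_rows


def projected_udf_minutes(staging_minutes: float, staging_rows: int, prod_rows: int) -> float:
    """Per-row UDF overhead scales linearly with rows on the same cluster."""
    return staging_minutes * prod_rows / staging_rows


def mb_per_partition(shuffle_mb: float, shuffle_partitions: int = 200) -> float:
    """Average post-shuffle partition size at a fixed partition count (default 200)."""
    return shuffle_mb / shuffle_partitions


if __name__ == "__main__":
    staging, prod = 10_000_000, 2_000_000_000
    print(f"{sample_fraction(staging, prod):.1%}")  # 0.5%
    print(projected_udf_minutes(2, staging, prod))  # 400.0
    print(mb_per_partition(2_000_000))              # 10000.0 (2TB shuffle -> 10GB partitions)
```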
Developer and founder of TheCodeForge. I built this site because I was tired of tutorials that explain what to type without explaining why it works. Every article here is written to make concepts actually click.