import pandas as pd
from pyspark.sql import functions as F, Window
from pyspark.sql.types import (
StructType, StructField, StringType, IntegerType, LongType,
DoubleType, TimestampType,
)
from delta.tables import DeltaTable
# =================================================================
# READ
# =================================================================
# Two alternative sources — the second assignment replaces the first `df`.
df = spark.read.table("catalog.schema.my_table")  # table registered in the catalog
df = spark.read.format("parquet").load("abfss://container@account.dfs.core.windows.net/path/")  # files by storage path
# CSV with explicit schema (avoid inferSchema on large data: inference
# costs an extra pass over the input and may guess types wrongly)
# NOTE(review): file-based sources generally read every column as nullable,
# so nullable=False on "id" is most likely not enforced — confirm, and use
# na.drop / a filter if the constraint matters.
schema = StructType([
StructField("id", IntegerType(), False),
StructField("name", StringType(), True),
StructField("ts", TimestampType(), True),
])
# With a user-supplied schema, CSV columns are mapped by POSITION; the header
# line is skipped rather than matched by name (enforceSchema defaults to true).
df = (spark.read.format("csv")
.option("header", "true")  # first line holds column names, not data
.schema(schema)
.load("/Volumes/catalog/schema/vol/data.csv"))
# =================================================================
# SELECT / RENAME / CAST
# =================================================================
# Every call returns a NEW DataFrame; assign the result to keep it.
df.select(
    "id",
    F.col("name").alias("user_name"),        # rename while selecting
    F.col("ts").cast("date").alias("day"),   # timestamp -> date
)
df.withColumnRenamed("ts", "event_ts")                    # rename one column
df.withColumn("amount", F.col("amount").cast("double"))   # same name => column replaced
df.drop("internal_col", "debug_col")                      # unknown names are a no-op
# =================================================================
# FILTER / WHEN
# =================================================================
# Parenthesize each comparison: in Python, & binds tighter than == and >.
df.where(
    (F.col("status") == "active")
    & (F.col("amount") > 0)
)
df.withColumn(
    "tier",
    F.when(F.col("amount") >= 1000, "gold")   # first matching branch wins
    .when(F.col("amount") >= 100, "silver")
    .otherwise("bronze"),
)
# =================================================================
# GROUP BY / AGGREGATE
# =================================================================
# One output row per (region, day). countDistinct is exact; the streaming
# section uses approx_count_distinct as a cheaper estimate.
(df.groupBy("region", "day")
.agg(F.sum("amount").alias("total"),
F.countDistinct("user_id").alias("users"),
F.avg("amount").alias("avg_amount")))
# =================================================================
# JOINS
# =================================================================
# `dim` / `small_dim` are placeholder DataFrames (not defined in this file).
df.join(dim, on="id", how="left")  # join by column name: output keeps a single "id"
df.join(dim, df.user_id == dim.id, "inner").drop(dim.id)  # expression join keeps both key columns; drop dim's copy
# Broadcast the small side to avoid shuffle
# (small_dim is shipped to every executor, so df is joined without being
# shuffled — only for tables that comfortably fit in executor memory)
from pyspark.sql.functions import broadcast
df.join(broadcast(small_dim), on="id", how="left")
# =================================================================
# NULLS / DEDUP
# =================================================================
df.na.fill({"amount": 0, "name": "unknown"})  # per-column replacement values
df.na.drop(subset=["id"])  # drop rows whose "id" is null
df.dropDuplicates(["user_id", "event_ts"])  # keeps one unspecified row per key; use row_number() for "latest wins"
# =================================================================
# EXPLODE / ARRAYS / STRUCTS
# =================================================================
df.withColumn("item", F.explode("items")) # array -> rows
# NOTE: explode drops rows whose array is null or empty; F.explode_outer keeps them (item = null).
df.select("id", F.col("payload.user.email").alias("email")) # nested struct access
df.withColumn("tags_str", F.concat_ws(",", F.col("tags"))) # array -> string
# (concat_ws skips null elements)
# =================================================================
# PERFORMANCE HELPERS
# =================================================================
df.repartition(200, "region") # increase parallelism, key-aware
# (full shuffle into 200 hash partitions keyed on "region")
df.coalesce(1) # reduce partitions without full shuffle
# (caution: a drastic coalesce such as 1 can also squeeze the upstream work onto very few tasks)
df.cache(); df.count() # materialize cache
# (cache() is lazy; the count() action is what actually fills the cache)
spark.conf.set("spark.sql.shuffle.partitions", "auto") # AQE handles it on DBR
# NOTE: "auto" is a Databricks Runtime value; open-source Spark expects an
# integer here and relies on AQE partition coalescing (spark.sql.adaptive.*).
# =================================================================
# WINDOW FUNCTIONS
# =================================================================
# --- Ranking family (note the differences) ---
w = Window.partitionBy("category").orderBy(F.col("score").desc())
df.withColumn("row_number", F.row_number().over(w)) # 1,2,3,4 — no ties
# (row_number breaks ties in "score" arbitrarily — add a tiebreaker column to orderBy for repeatable results)
df.withColumn("rank", F.rank().over(w)) # 1,2,2,4 — gaps after ties
df.withColumn("dense_rank", F.dense_rank().over(w)) # 1,2,2,3 — no gaps
df.withColumn("pct_rank", F.percent_rank().over(w)) # relative position 0..1
df.withColumn("ntile_4", F.ntile(4).over(w)) # quartile buckets
# Top-N per group
# (at most 3 rows per category even when scores tie; filter on rank() instead to keep all tied rows)
df.withColumn("rn", F.row_number().over(w)).filter(F.col("rn") <= 3).drop("rn")
# --- Lag / Lead — compare to neighbouring rows ---
w_time = Window.partitionBy("user_id").orderBy("event_ts")
df = (
    df
    .withColumn("prev_amount", F.lag("amount", 1).over(w_time))   # null on each user's first row
    .withColumn("next_amount", F.lead("amount", 1).over(w_time))  # null on each user's last row
    .withColumn("delta", F.col("amount") - F.col("prev_amount"))
    # Guard the division: under ANSI mode (the default in Spark 4) a zero
    # divisor raises DIVIDE_BY_ZERO instead of yielding null. when() without
    # otherwise() gives null when prev_amount is 0 or null.
    .withColumn(
        "pct_change",
        F.when(F.col("prev_amount") != 0, F.col("delta") / F.col("prev_amount")),
    )
)
# Time since previous event (seconds); null on each user's first row
df.withColumn(
    "gap_sec",
    F.col("event_ts").cast("long") - F.lag("event_ts").over(w_time).cast("long"),
)
# --- First / Last in a window (ignore nulls) ---
# An ordered window defaults to a running frame (partition start .. current
# row). With that frame, first() is null on rows before the user's first
# non-null status, and last() just echoes the latest value seen so far.
# Both therefore need an explicit whole-partition frame to return the
# user's overall first / last non-null status on every row.
w_user_all = w_time.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
df.withColumn("first_status", F.first("status", ignorenulls=True).over(w_user_all))
df.withColumn("last_status", F.last("status", ignorenulls=True).over(w_user_all))
# --- Rolling windows — by rows vs by range ---
w_rows = w_time.rowsBetween(-6, 0) # last 7 rows
# (current row plus the 6 before it, however far apart in time they are)
df.withColumn("ma7_rows", F.avg("amount").over(w_rows))
w_days = (Window.partitionBy("user_id") # last 7 days
.orderBy(F.col("event_ts").cast("long"))
.rangeBetween(-7 * 86400, 0))
# rangeBetween offsets apply to the ORDER BY value — here epoch seconds, so
# 7 * 86400 s = 7 days. Both bounds are inclusive, and rows tied with the
# current timestamp are included.
df.withColumn("sum_7d", F.sum("amount").over(w_days))
# --- Sessionization — new session after 30 min inactivity ---
gap = 30 * 60  # inactivity threshold in seconds
df = (df
.withColumn("prev_ts", F.lag("event_ts").over(w_time))
# 1 when the gap since the previous event exceeds `gap`, else 0; null on a
# user's first event (no previous row)
.withColumn("new_session",
((F.col("event_ts").cast("long") - F.col("prev_ts").cast("long")) > gap)
.cast("int"))
.fillna({"new_session": 1})  # a user's first event always opens a session
# running total of session starts = per-user session number (1, 2, ...)
.withColumn("session_id",
F.sum("new_session").over(w_time.rowsBetween(Window.unboundedPreceding, 0))))
# NOTE: session_id restarts at 1 for every user; pair it with user_id for a globally unique key.
# --- Gaps & islands — collapse consecutive same-status runs ---
# Inside a run of consecutive rows with the same status, both row numbers
# advance together, so their difference is constant. It jumps when another
# status interrupts the run. "grp" is therefore a run label, unique only
# within a (user_id, status) pair.
df = (df
.withColumn("grp",
F.row_number().over(w_time)
- F.row_number().over(Window.partitionBy("user_id", "status").orderBy("event_ts")))
# then group by (user_id, status, grp) for one row per contiguous run
# (assumes event_ts is unique per user; ties make both row numbers arbitrary)
)
# =================================================================
# UDFs & PANDAS APIS (slowest -> fastest)
# Prefer native functions first; reach for these only when needed.
# =================================================================
# --- 1. Plain Python UDF — simplest, slowest (row-at-a-time) ---
@F.udf(returnType=StringType())
def normalize(s):
    """Strip surrounding whitespace from *s* and lower-case it; NULL stays NULL.

    Compare with ``is not None`` rather than truthiness. Otherwise an empty
    string would silently become SQL NULL, while a whitespace-only string
    still became ``""``.
    """
    return s.strip().lower() if s is not None else None


df.withColumn("name_norm", normalize("name"))
# --- 2. Pandas (vectorized) UDF — Arrow-based, much faster ---
@F.pandas_udf(DoubleType())
def celsius_to_f(c: pd.Series) -> pd.Series:
    """Convert a whole batch of Celsius readings to Fahrenheit at once.

    Spark hands the UDF one Arrow batch as a pandas Series, so this is a
    single vectorized expression rather than a per-row call.
    """
    scaled = c * 9 / 5
    return scaled + 32


df.withColumn("temp_f", celsius_to_f("temp_c"))
# Series -> scalar (aggregation pandas UDF)
@F.pandas_udf(DoubleType())
def gmean(v: pd.Series) -> float:
    """Return the geometric mean of the non-null values in one group.

    Nulls presumably arrive as NaN in the pandas Series. They are dropped
    BEFORE counting: ``Series.prod`` skips NaN but ``len`` counted it, which
    skewed the exponent (an all-null group came out as 1.0).

    Computed as exp(mean(log v)) instead of prod(v) ** (1/n), because the
    running product overflows to inf or underflows to 0 on large groups.

    Returns NaN when the group has no non-null values. A zero value gives
    0.0; negative values give NaN, since the geometric mean is undefined
    for them.
    """
    import numpy as np

    vals = v.dropna().to_numpy(dtype=float)
    if vals.size == 0:
        return float("nan")
    # log(0) = -inf and log(<0) = NaN are intended outcomes here; silence the warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        return float(np.exp(np.log(vals).mean()))


df.groupBy("region").agg(gmean("ratio").alias("geo_mean"))
# --- 3. mapInPandas — iterate partitions as pandas DataFrames ---
# Great for: per-file parsing, calling libs (scipy/h5py), custom IO
# Output schema for `transform` below: every frame it yields must have
# columns matching these fields.
map_schema = StructType([
StructField("id", LongType()),
StructField("feature", DoubleType()),
])
def transform(iterator):
    """Yield ``id`` plus ``feature = 2 * raw`` for every incoming pandas chunk.

    ``mapInPandas`` feeds an iterator of pandas DataFrames (a partition may
    arrive as several chunks). Each chunk gains a ``feature`` column in place,
    and only the ``id`` / ``feature`` columns are emitted.
    """
    keep = ["id", "feature"]
    for chunk in iterator:
        chunk["feature"] = 2.0 * chunk["raw"]
        yield chunk[keep]
df.mapInPandas(transform, schema=map_schema)  # lazy: transform only runs when an action executes
# --- 4. applyInPandas — grouped map, full group as one pandas DataFrame ---
# Great for: per-group model fit, interpolation, ranking in pandas
# Output schema for `fit_trend`: one row per user_id with its fitted slope.
fit_schema = StructType([
StructField("user_id", LongType()),
StructField("slope", DoubleType()),
])
def fit_trend(key, pdf):
    """Fit a straight line to one user's ``amount`` series and return its slope.

    Used with ``groupBy("user_id").applyInPandas``. *key* is the tuple of
    grouping values and *pdf* holds every row of that group. The signature
    must stay at two arguments: applyInPandas chooses between ``f(pdf)`` and
    ``f(key, pdf)`` by argument count.

    Spark does not guarantee row order inside a group. Rows are therefore
    sorted by ``event_ts`` (when that column is present) before being
    numbered 0..n-1; otherwise the x-axis, and thus the slope, would depend
    on arbitrary shuffle order. Null amounts are skipped but keep their
    position on the x-axis. Groups with fewer than two non-null amounts get
    slope 0.0.

    Returns:
        A one-row DataFrame with ``user_id`` (``key[0]``) and ``slope``,
        matching ``fit_schema``.
    """
    import numpy as np

    if "event_ts" in pdf.columns:
        pdf = pdf.sort_values("event_ts", kind="mergesort")  # stable: ties keep incoming order
    y = pdf["amount"].to_numpy(dtype=float)  # None / NaN -> NaN
    x = np.arange(len(y), dtype=float)
    valid = ~np.isnan(y)
    slope = float(np.polyfit(x[valid], y[valid], 1)[0]) if valid.sum() > 1 else 0.0
    return pd.DataFrame({"user_id": [key[0]], "slope": [slope]})
df.groupBy("user_id").applyInPandas(fit_trend, schema=fit_schema)
# Notes:
# - spark.sql.execution.arrow.pyspark.enabled = true (default on DBR)
#   (this governs toPandas()/createDataFrame(pandas_df); pandas UDFs and
#   applyInPandas always exchange data via Arrow)
# - applyInPandas loads the WHOLE group into memory — watch skewed groups
# - applyInPandas dispatches on the function's argument count: f(pdf) or f(key, pdf)
# - Register a UDF for SQL: spark.udf.register("normalize", normalize)
# =================================================================
# STRUCTURED STREAMING
# =================================================================
# --- Auto Loader — incremental file ingestion with schema inference ---
stream = (spark.readStream
.format("cloudFiles")
.option("cloudFiles.format", "json")
.option("cloudFiles.schemaLocation", "/Volumes/cat/sch/_schema/events")  # where the inferred schema is persisted and evolved
.option("cloudFiles.inferColumnTypes", "true")  # without it, JSON columns are inferred as strings
.option("cloudFiles.maxFilesPerTrigger", 100) # backpressure / rate limit
.load("/Volumes/cat/sch/landing/events/"))
# --- Watermarked windowed aggregation ---
# NOTE(review): withWatermark requires event_ts to be a timestamp column —
# confirm the inferred schema yields that, or cast it first.
agg = (stream
.withWatermark("event_ts", "10 minutes")  # tolerate events up to 10 min late; older window state is dropped
.groupBy(F.window("event_ts", "5 minutes"), "region")  # 5-minute tumbling windows per region
.agg(F.sum("amount").alias("total"),
F.approx_count_distinct("user_id").alias("users")))  # exact distinct aggregation isn't supported on streams
# --- Write — append to Delta with checkpoint (always set one!) ---
# In append mode each (window, region) row is emitted exactly once, after
# the watermark passes the end of its window.
(agg.writeStream
.format("delta")
.outputMode("append")
.option("checkpointLocation", "/Volumes/cat/sch/_chk/agg")  # progress + aggregation state; one location per query
.trigger(availableNow=True) # batch-like: drain then stop
.toTable("catalog.schema.events_5min"))  # starts the query; the returned StreamingQuery is not kept here
# --- foreachBatch — UPSERT (MERGE) per micro-batch ---
def upsert_to_delta(batch_df, batch_id):
    """MERGE one micro-batch into ``catalog.schema.target``, keyed on ``id``.

    Called by ``foreachBatch`` with the batch DataFrame and its id
    (``batch_id`` is unused here).

    MERGE fails when several source rows match the same target row, so the
    batch is first reduced to the newest row per ``id``. A matched row is
    only overwritten when the incoming event is at least as new as the
    stored one. That keeps "newest wins" across batches: a late file with
    older events cannot roll a row back. Target rows with a null
    ``event_ts`` are always overwritten.

    Uses ``batch_df.sparkSession`` rather than a global ``spark``, because
    the function may run where the notebook's session object is not
    available (e.g. Spark Connect).
    """
    newest_first = Window.partitionBy("id").orderBy(F.col("event_ts").desc())
    deduped = (
        batch_df
        .withColumn("rn", F.row_number().over(newest_first))
        .filter("rn = 1")
        .drop("rn")
    )
    (
        DeltaTable.forName(batch_df.sparkSession, "catalog.schema.target").alias("t")
        .merge(deduped.alias("s"), "t.id = s.id")
        .whenMatchedUpdateAll(condition="t.event_ts IS NULL OR s.event_ts >= t.event_ts")
        .whenNotMatchedInsertAll()
        .execute()
    )
# Keep the handle: .start() returns a StreamingQuery, which the monitoring
# calls below operate on.
q = (
    stream.writeStream
    .foreachBatch(upsert_to_delta)
    .option("checkpointLocation", "/Volumes/cat/sch/_chk/upsert")
    .trigger(processingTime="1 minute")
    .start()
)
# Trigger modes:
# .trigger(processingTime="30 seconds") # fixed micro-batch interval
# .trigger(availableNow=True) # process all available, then stop
# (Spark 4.1 adds real-time mode — see post 002)
# Monitoring:
# q.lastProgress # latency, input/processed rows per second
# q.status # is it actively processing?
# spark.streams.active # list all running queries
# q.awaitTermination() / q.stop() # block until the query ends / stop it