Vectorized Python User-defined Table Functions (UDTFs)#
Spark 4.1 introduces the Vectorized Python user-defined table function (UDTF), a new type of user-defined table-valued function. It can be used via the @arrow_udtf decorator.
Unlike scalar functions, which return a single result value from each call, a UDTF is invoked in the FROM clause of a query and returns an entire table as output.
Unlike the traditional Python UDTF, which evaluates row by row, the Vectorized Python UDTF operates directly on Apache Arrow arrays and record batches. This lets you leverage vectorized operations and improve the performance of your UDTF.
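For a quick first look before the full interface below, here is a minimal sketch of an Arrow UDTF that upper-cases a string column for every input batch. The class name, the returnType, and the input column name "value" are illustrative assumptions, not fixed by the API.
import pyarrow as pa
import pyarrow.compute as pc
from pyspark.sql.functions import arrow_udtf

# Minimal sketch: upper-case the "value" column of each Arrow batch at once,
# instead of processing the input row by row.
@arrow_udtf(returnType="upper_value string")
class UppercaseUDTF:
    def eval(self, batch: pa.RecordBatch):
        yield pa.table({"upper_value": pc.utf8_upper(batch.column("value"))})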
Vectorized Python UDTF Interface#
class NameYourArrowPythonUDTF:

    def __init__(self) -> None:
        """
        Initializes the user-defined table function (UDTF). This is optional.

        This method serves as the default constructor and is called once when the
        UDTF is instantiated on the executor side.

        Any class fields assigned in this method will be available for subsequent
        calls to the `eval`, `terminate` and `cleanup` methods.

        Notes
        -----
        - You cannot create or reference the Spark session within the UDTF. Any
          attempt to do so will result in a serialization error.
        """
        ...

    def eval(self, *args: Any) -> Iterator[pa.RecordBatch | pa.Table]:
        """
        Evaluates the function using the given input arguments.

        This method is required and must be implemented.

        Argument Mapping:
        - Each provided scalar expression maps to exactly one value in the
          `*args` list with type `pa.Array`.
        - Each provided table argument maps to a `pa.RecordBatch` object containing
          the columns in the order they appear in the provided input table,
          and with the names computed by the query analyzer.

        This method is called on every batch of input rows and can produce zero or more
        output pyarrow record batches or pyarrow tables. Each output batch must conform
        to the return schema specified for the UDTF.

        Parameters
        ----------
        *args : Any
            Arbitrary positional arguments representing the input to the UDTF.

        Yields
        ------
        iterator
            An iterator of `pa.RecordBatch` or `pa.Table` objects representing a batch of rows
            in the UDTF result table. Yield as many times as needed to produce multiple batches.

        Notes
        -----
        - UDTFs can instead accept keyword arguments during the function call if needed.
        - The `eval` method can raise a `SkipRestOfInputTableException` to indicate that the
          UDTF wants to skip consuming all remaining rows from the current partition of the
          input table. This will cause the UDTF to proceed directly to the `terminate` method.
        - The `eval` method can raise any other exception to indicate that the UDTF should be
          aborted entirely. This will cause the UDTF to skip the `terminate` method and proceed
          directly to the `cleanup` method, and then the exception will be propagated to the
          query processor, causing the invoking query to fail.

        Examples
        --------
        This `eval` method takes a table argument and yields an arrow record batch for each input batch.

        >>> def eval(self, batch: pa.RecordBatch):
        ...     yield batch

        This `eval` method takes a table argument and yields a pyarrow table for each input batch.

        >>> def eval(self, batch: pa.RecordBatch):
        ...     yield pa.table({"x": batch.column(0), "y": batch.column(1)})

        This `eval` method takes both table and scalar arguments and yields a pyarrow table for each input batch.

        >>> def eval(self, batch: pa.RecordBatch, x: pa.Array):
        ...     yield pa.table({"x": x, "y": batch.column(0)})
        """
        ...

    def terminate(self) -> Iterator[pa.RecordBatch | pa.Table]:
        """
        Called when the UDTF has successfully processed all input rows.

        This method is optional to implement and is useful for performing any
        finalization operations after the UDTF has finished processing
        all rows. It can also be used to yield additional rows if needed.
        Table functions that consume all rows in the entire input partition
        and then compute and return the entire output table can do so from
        this method as well (please be mindful of memory usage when doing
        this).

        If any exceptions occur during input row processing, this method
        won't be called.

        Yields
        ------
        iterator
            An iterator of `pa.RecordBatch` or `pa.Table` objects representing a batch of rows
            in the UDTF result table. Yield as many times as needed to produce multiple batches.

        Examples
        --------
        >>> def terminate(self) -> Iterator[pa.RecordBatch | pa.Table]:
        ...     yield pa.table({"x": pa.array([1, 2, 3])})
        """
        ...

    def cleanup(self) -> None:
        """
        Invoked after the UDTF completes processing input rows.

        This method is optional to implement and is useful for final cleanup
        regardless of whether the UDTF processed all input rows successfully
        or was aborted due to exceptions.

        Examples
        --------
        >>> def cleanup(self) -> None:
        ...     self.conn.close()
        """
        ...
Defining the Output Schema#
The return type of the UDTF defines the schema of the table it outputs.
You can specify it in the @arrow_udtf
decorator.
It must be either a StructType:
@arrow_udtf(returnType=StructType().add("c1", StringType()).add("c2", IntegerType()))
class YourArrowPythonUDTF:
    ...
or a DDL string representing a struct type:
@arrow_udtf(returnType="c1 string, c2 int")
class YourArrowPythonUDTF:
...
Emitting Output Rows#
The eval and terminate methods then emit zero or more output batches conforming to this schema by yielding pa.RecordBatch or pa.Table objects.
@arrow_udtf(returnType="c1 int, c2 int")
class YourArrowPythonUDTF:
    def eval(self, batch: pa.RecordBatch):
        yield pa.table({"c1": batch.column(0), "c2": batch.column(1)})
You can also yield multiple pyarrow tables in the eval method.
@arrow_udtf(returnType="c1 int")
class YourArrowPythonUDTF:
def eval(self, batch: pa.RecordBatch):
yield pa.table({"c1": batch.column(0)})
yield pa.table({"c1": batch.column(1)})
You can also yield multiple pyarrow record batches in the eval method.
@arrow_udtf(returnType="c1 int")
class YourArrowPythonUDTF:
def eval(self, batch: pa.RecordBatch):
new_batch = pa.record_batch(
{"c1": batch.column(0).slice(0, len(batch) // 2)})
yield new_batch
Usage Examples#
Here’s how to use these UDTFs with the DataFrame API:
import pyarrow as pa
from pyspark.sql.functions import arrow_udtf

@arrow_udtf(returnType="c1 string")
class MyArrowPythonUDTF:
    def eval(self, batch: pa.RecordBatch):
        yield pa.table({"c1": batch.column("value")})

df = spark.range(10).selectExpr("id", "cast(id as string) as value")
MyArrowPythonUDTF(df.asTable()).show()
# Result:
# +---+
# | c1|
# +---+
# | 0|
# | 1|
# | 2|
# | 3|
# | 4|
# | 5|
# | 6|
# | 7|
# | 8|
# | 9|
# +---+
# Register the UDTF
spark.udtf.register("my_arrow_udtf", MyArrowPythonUDTF)
# Use in SQL queries
df = spark.sql("""
SELECT * FROM my_arrow_udtf(TABLE(SELECT id, cast(id as string) as value FROM range(10)))
""")
TABLE Argument#
Arrow UDTFs can take a TABLE argument. When your UDTF receives a TABLE argument,
its eval
method is called with a pyarrow.RecordBatch
containing the input
table’s columns, and any additional scalar/struct expressions are passed as
pyarrow.Array
values.
Key points:
- The TABLE argument is a single pa.RecordBatch; access columns by name or index.
- Scalar arguments (including structs) are pa.Array values, not RecordBatch.
- Named and positional arguments are both supported in SQL.
Example (DataFrame API):
import pyarrow as pa
from typing import Iterator
from pyspark.sql.functions import arrow_udtf

@arrow_udtf(returnType="value int")
class EchoTable:
    def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
        # Return the input column named "value" as-is
        yield pa.table({"value": batch.column("value")})

df = spark.range(5).selectExpr("id as value")
EchoTable(df.asTable()).show()
# Result:
# +-----+
# |value|
# +-----+
# | 0|
# | 1|
# | 2|
# | 3|
# | 4|
# +-----+
Example (SQL): TABLE plus a scalar threshold
import pyarrow as pa
import pyarrow.compute as pc
from typing import Iterator
from pyspark.sql.functions import arrow_udtf

# Keep rows with value > threshold; works with SQL using TABLE + scalar argument
@arrow_udtf(returnType="partition_key int, value int")
class ThresholdFilter:
    def eval(self, batch: pa.RecordBatch, threshold: pa.Array) -> Iterator[pa.Table]:
        tbl = pa.table(batch)
        thr = int(threshold.cast(pa.int64())[0].as_py())
        mask = pc.greater(tbl["value"], thr)
        yield tbl.filter(mask)

spark.udtf.register("threshold_filter", ThresholdFilter)

spark.createDataFrame([(1, 10), (1, 30), (2, 5)], "partition_key int, value int").createOrReplaceTempView("v")
spark.sql(
    """
    SELECT *
    FROM threshold_filter(
        TABLE(v),
        10
    )
    ORDER BY partition_key, value
    """
).show()
# Result:
# +-------------+-----+
# |partition_key|value|
# +-------------+-----+
# | 1| 30|
# +-------------+-----+
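The same UDTF can also be invoked through the DataFrame API. The sketch below assumes the scalar threshold can be passed as a literal Column, mirroring the table-argument call style shown earlier; source_df is an illustrative name.
from pyspark.sql.functions import lit

# Assumed DataFrame-API invocation: the table argument via asTable(), the
# scalar threshold as a literal Column (mirroring row-based Python UDTFs).
source_df = spark.createDataFrame([(1, 10), (1, 30), (2, 5)], "partition_key int, value int")
ThresholdFilter(source_df.asTable(), lit(10)).show()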
PARTITION BY and ORDER BY#
Arrow UDTFs support TABLE(...) PARTITION BY ... ORDER BY .... Think of it as
“process rows group by group, and in a specific order within each group”.
Semantics:
- PARTITION BY groups rows by the given keys; your UDTF runs for each group independently.
- ORDER BY controls the row order within each group as seen by eval.
- eval may be called multiple times per group; accumulate state and typically emit the group’s result in terminate.
Example: Aggregation per key with terminate#
PARTITION BY is especially useful for per-group aggregation. eval may be called multiple times for the same group as rows arrive in batches, so keep running totals in the UDTF instance and emit the final row in terminate.
import pyarrow as pa
import pyarrow.compute as pc
from typing import Iterator
from pyspark.sql.functions import arrow_udtf

@arrow_udtf(returnType="user_id int, total_amount int, rows int")
class SumPerUser:
    def __init__(self):
        self._user = None
        self._sum = 0
        self._count = 0

    def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
        tbl = pa.table(batch)
        # All rows in this batch belong to the same user within a partition
        self._user = pc.unique(tbl["user_id"]).to_pylist()[0]
        self._sum += pc.sum(tbl["amount"]).as_py()
        self._count += tbl.num_rows
        return iter(())  # emit once in terminate

    def terminate(self) -> Iterator[pa.Table]:
        if self._user is not None:
            yield pa.table({
                "user_id": pa.array([self._user], pa.int32()),
                "total_amount": pa.array([self._sum], pa.int32()),
                "rows": pa.array([self._count], pa.int32()),
            })

spark.udtf.register("sum_per_user", SumPerUser)

spark.createDataFrame(
    [(1, 10), (2, 5), (1, 20), (2, 15), (3, 7)],
    "user_id int, amount int",
).createOrReplaceTempView("purchases")

spark.sql(
    """
    SELECT *
    FROM sum_per_user(
        TABLE(purchases)
        PARTITION BY user_id
    )
    ORDER BY user_id
    """
).show()
# Result:
# +-------+------------+----+
# |user_id|total_amount|rows|
# +-------+------------+----+
# | 1| 30| 2|
# | 2| 20| 2|
# | 3| 7| 1|
# +-------+------------+----+
Why terminate? eval
may run multiple times per group if the input is split into
several batches. Emitting the aggregated row in terminate
guarantees exactly one
output row per group after all its rows have been processed.
Example: Top reviews per product using ORDER BY#
import pyarrow as pa
import pyarrow.compute as pc
from typing import Iterator, Optional
from pyspark.sql.functions import arrow_udtf, SkipRestOfInputTableException

@arrow_udtf(returnType="product_id int, review_id int, rating int, review string")
class TopReviewsPerProduct:
    TOP_K = 3

    def __init__(self):
        self._product = None
        self._seen = 0
        self._batches: list[pa.Table] = []
        self._top: Optional[pa.Table] = None

    def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
        tbl = pa.table(batch)
        if tbl.num_rows == 0:
            return iter(())
        products = pc.unique(tbl["product_id"]).to_pylist()
        assert len(products) == 1, f"Expected one product per batch, saw {products}"
        product = products[0]
        if self._product is None:
            self._product = product
        else:
            assert self._product == product, f"Mixed products {self._product} and {product}"
        self._batches.append(tbl)
        self._seen += tbl.num_rows
        if self._seen >= self.TOP_K and self._top is None:
            combined = pa.concat_tables(self._batches)
            self._top = combined.slice(0, self.TOP_K)
            raise SkipRestOfInputTableException(
                f"Collected top {self.TOP_K} reviews for product {self._product}"
            )
        return iter(())

    def terminate(self) -> Iterator[pa.Table]:
        if self._product is None:
            return iter(())
        if self._top is None:
            combined = pa.concat_tables(self._batches) if self._batches else pa.table({})
            limit = min(self.TOP_K, self._seen)
            self._top = combined.slice(0, limit)
        yield self._top

spark.udtf.register("top_reviews_per_product", TopReviewsPerProduct)

spark.createDataFrame(
    [
        (101, 1, 5, "Amazing battery life"),
        (101, 2, 5, "Still great after a month"),
        (101, 3, 4, "Solid build"),
        (101, 4, 3, "Average sound"),
        (202, 5, 5, "My go-to lens"),
        (202, 6, 4, "Sharp and bright"),
        (202, 7, 4, "Great value"),
    ],
    "product_id int, review_id int, rating int, review string",
).createOrReplaceTempView("reviews")

spark.sql(
    """
    SELECT *
    FROM top_reviews_per_product(
        TABLE(reviews)
        PARTITION BY (product_id)
        ORDER BY (rating DESC, review_id)
    )
    ORDER BY product_id, rating DESC, review_id
    """
).show(truncate=False)
# Result:
# +----------+---------+------+--------------------------+
# |product_id|review_id|rating|review |
# +----------+---------+------+--------------------------+
# | 101| 1| 5|Amazing battery life |
# | 101| 2| 5|Still great after a month |
# | 101| 3| 4|Solid build |
# | 202| 5| 5|My go-to lens |
# | 202| 6| 4|Sharp and bright |
# | 202| 7| 4|Great value |
# +----------+---------+------+--------------------------+
Best Practices#
- Stream work from eval() when possible. Yielding one pa.Table per Arrow batch keeps memory bounded and shortens feedback loops; reserve terminate() for true per-partition operations (see the sketch after this list).
- Keep per-partition state tiny and reset it promptly. If you only need the first N rows, raise SkipRestOfInputTableException after collecting them so Spark skips the rest of the partition.
- Guard external calls with short timeouts and operate on the current batch instead of deferring to terminate; this avoids giant buffers and keeps retries narrow.
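To make the first two points concrete, here is a minimal sketch that streams one pa.Table per input batch and raises SkipRestOfInputTableException once a fixed number of rows has been emitted. The output schema, the column name "value", and the LIMIT_ROWS constant are illustrative assumptions.
import pyarrow as pa
from typing import Iterator
from pyspark.sql.functions import arrow_udtf, SkipRestOfInputTableException

@arrow_udtf(returnType="value int")
class StreamFirstRows:
    LIMIT_ROWS = 1000

    def __init__(self):
        self._emitted = 0

    def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
        # Emit at most the remaining quota from this batch, without buffering
        # earlier batches in the UDTF instance.
        remaining = self.LIMIT_ROWS - self._emitted
        tbl = pa.table(batch).select(["value"]).slice(0, remaining)
        self._emitted += tbl.num_rows
        if tbl.num_rows > 0:
            yield tbl
        if self._emitted >= self.LIMIT_ROWS:
            # Skip the remaining rows of this partition; Spark then proceeds to
            # terminate (if defined) for this instance.
            raise SkipRestOfInputTableException("row limit reached")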
When to use Arrow UDTFs vs Other UDTFs#
- Prefer arrow_udtf when the logic is naturally vectorized, you can stay within Python, and the input/output schema is Arrow-friendly. You gain batch-friendly performance and native interoperability with PySpark DataFrames.
- Stick with the classic (row-based) Python UDTF when you only need simple per-row expansion, or when your logic depends on Python objects that Arrow cannot represent cleanly (a row-based sketch follows this list for comparison).
- Use SQL UDTFs if the functionality is performance critical and the logic can be represented in SQL.
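For comparison, here is a minimal row-based Python UDTF built with the classic @udtf decorator; it yields one Python tuple per output row instead of Arrow batches. The SplitWords name and its schema are illustrative.
from pyspark.sql.functions import lit, udtf

# Classic row-based Python UDTF: eval runs once per input row and yields
# plain Python tuples matching the declared schema.
@udtf(returnType="word string")
class SplitWords:
    def eval(self, text: str):
        for w in (text or "").split():
            yield (w,)

SplitWords(lit("hello arrow udtf")).show()
# Expected to produce one row per word: hello, arrow, udtf.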
More Examples#
Example: Simple anomaly detection per device#
Compute simple per-device stats and return them from terminate; this pattern is useful for anomaly detection workflows that first summarize distributions by key.
import pyarrow as pa
import pyarrow.compute as pc
from typing import Iterator
from pyspark.sql.functions import arrow_udtf

@arrow_udtf(returnType="device_id int, count int, mean double, stddev double, max_value int")
class DeviceStats:
    def __init__(self):
        self._device = None
        self._count = 0
        self._sum = 0.0
        self._sumsq = 0.0
        self._max = None

    def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
        tbl = pa.table(batch)
        self._device = pc.unique(tbl["device_id"]).to_pylist()[0]
        vals = tbl["reading"].cast(pa.float64())
        self._count += len(vals)
        self._sum += pc.sum(vals).as_py() or 0.0
        self._sumsq += pc.sum(pc.multiply(vals, vals)).as_py() or 0.0
        cur_max = pc.max(vals).as_py()
        self._max = cur_max if self._max is None else max(self._max, cur_max)
        return iter(())

    def terminate(self) -> Iterator[pa.Table]:
        if self._device is not None and self._count > 0:
            mean = self._sum / self._count
            var = max(self._sumsq / self._count - mean * mean, 0.0)
            std = var ** 0.5
            # Round to 2 decimal places for display
            mean_rounded = round(mean, 2)
            std_rounded = round(std, 2)
            yield pa.table({
                "device_id": pa.array([self._device], pa.int32()),
                "count": pa.array([self._count], pa.int32()),
                "mean": pa.array([mean_rounded], pa.float64()),
                "stddev": pa.array([std_rounded], pa.float64()),
                "max_value": pa.array([int(self._max)], pa.int32()),
            })

spark.udtf.register("device_stats", DeviceStats)

spark.createDataFrame(
    [(1, 10), (1, 12), (1, 100), (2, 5), (2, 7)],
    "device_id int, reading int",
).createOrReplaceTempView("readings")

spark.sql(
    """
    SELECT *
    FROM device_stats(
        TABLE(readings)
        PARTITION BY device_id
    )
    ORDER BY device_id
    """
).show()
# Result:
# +---------+-----+-----+------+---------+
# |device_id|count| mean|stddev|max_value|
# +---------+-----+-----+------+---------+
# | 1| 3|40.67| 41.96| 100|
# | 2| 2| 6.0| 1.0| 7|
# +---------+-----+-----+------+---------+
Example: Arrow UDTFs as RDD map-style transforms#
Arrow UDTFs can replace many RDD.map / flatMap-style transforms with better performance and first-class SQL integration. Instead of mapping row by row in Python, you work on Arrow batches and return a table.
Example: tokenize text into words (flatMap-like)
import pyarrow as pa
from typing import Iterator
from pyspark.sql.functions import arrow_udtf

@arrow_udtf(returnType="doc_id int, word string")
class Tokenize:
    def eval(self, batch: pa.RecordBatch) -> Iterator[pa.Table]:
        tbl = pa.table(batch)
        # Split on whitespace; build flat arrays for (doc_id, word)
        doc_ids: list[int] = []
        words: list[str] = []
        for doc_id, text in zip(tbl["doc_id"].to_pylist(), tbl["text"].to_pylist()):
            for w in (text or "").split():
                doc_ids.append(doc_id)
                words.append(w)
        if doc_ids:
            yield pa.table({"doc_id": pa.array(doc_ids, pa.int32()), "word": pa.array(words)})

spark.udtf.register("tokenize", Tokenize)

spark.createDataFrame([(1, "spark is fast"), (2, "arrow udtf")], "doc_id int, text string").createOrReplaceTempView("docs")
spark.sql("SELECT * FROM tokenize(TABLE(docs)) ORDER BY doc_id, word").show()
# Result:
# +------+-----+
# |doc_id| word|
# +------+-----+
# | 1| fast|
# | 1| is|
# | 1|spark|
# | 2|arrow|
# | 2| udtf|
# +------+-----+