Skip to content

Commit 6c2454f

Browse files
authored
feat(python): support batch mode (#10)
1 parent 59edaf7 commit 6c2454f

File tree

2 files changed

+49
-4
lines changed

2 files changed

+49
-4
lines changed

python/databend_udf/udf.py

+31-4
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,17 @@ class ScalarFunction(UserDefinedFunction):
5959
_io_threads: Optional[int]
6060
_executor: Optional[ThreadPoolExecutor]
6161
_skip_null: bool
62+
_batch_mode: bool
6263

6364
def __init__(
64-
self, func, input_types, result_type, name=None, io_threads=None, skip_null=None
65+
self,
66+
func,
67+
input_types,
68+
result_type,
69+
name=None,
70+
io_threads=None,
71+
skip_null=None,
72+
batch_mode=False,
6573
):
6674
self._func = func
6775
self._input_schema = pa.schema(
@@ -78,6 +86,7 @@ def __init__(
7886
func.__name__ if hasattr(func, "__name__") else func.__class__.__name__
7987
)
8088
self._io_threads = io_threads
89+
self._batch_mode = batch_mode
8190
self._executor = (
8291
ThreadPoolExecutor(max_workers=self._io_threads)
8392
if self._io_threads is not None
@@ -98,7 +107,11 @@ def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
98107
_input_process_func(_list_field(field))(array)
99108
for array, field in zip(inputs, self._input_schema)
100109
]
101-
if self._executor is not None:
110+
111+
# evaluate the function for each row
112+
if self._batch_mode:
113+
column = self._func(*inputs)
114+
elif self._executor is not None:
102115
# concurrently evaluate the function for each row
103116
if self._skip_null:
104117
tasks = []
@@ -113,7 +126,6 @@ def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
113126
]
114127
column = [future.result() for future in tasks]
115128
else:
116-
# evaluate the function for each row
117129
if self._skip_null:
118130
column = []
119131
for row in range(batch.num_rows):
@@ -140,6 +152,7 @@ def udf(
140152
name: Optional[str] = None,
141153
io_threads: Optional[int] = 32,
142154
skip_null: Optional[bool] = False,
155+
batch_mode: Optional[bool] = False,
143156
) -> Callable:
144157
"""
145158
Annotation for creating a user-defined scalar function.
@@ -153,6 +166,7 @@ def udf(
153166
- skip_null: A boolean value specifying whether to skip NULL value. If it is set to True,
154167
NULL values will not be passed to the function,
155168
and the corresponding return value is set to NULL. Default to False.
169+
- batch_mode: A boolean value specifying whether to use batch mode. Default to False.
156170
157171
Example:
158172
```
@@ -170,6 +184,13 @@ def external_api(x):
170184
response = requests.get(my_endpoint + '?param=' + x)
171185
return response["data"]
172186
```
187+
188+
Batch mode example:
189+
```
190+
@udf(input_types=['INT', 'INT'], result_type='INT', batch_mode=True)
191+
def gcd(x, y):
192+
return [x_i if y_i == 0 else gcd(y_i, x_i % y_i) for x_i, y_i in zip(x, y)]
193+
```
173194
"""
174195

175196
if io_threads is not None and io_threads > 1:
@@ -180,10 +201,16 @@ def external_api(x):
180201
name,
181202
io_threads=io_threads,
182203
skip_null=skip_null,
204+
batch_mode=batch_mode,
183205
)
184206
else:
185207
return lambda f: ScalarFunction(
186-
f, input_types, result_type, name, skip_null=skip_null
208+
f,
209+
input_types,
210+
result_type,
211+
name,
212+
skip_null=skip_null,
213+
batch_mode=batch_mode,
187214
)
188215

189216

python/example/server.py

+18
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,23 @@ def gcd(x: int, y: int) -> int:
5555
return x
5656

5757

58+
@udf(
59+
name="gcd_batch",
60+
input_types=["INT", "INT"],
61+
result_type="INT",
62+
batch_mode=True,
63+
)
64+
def gcd_batch(x: list[int], y: list[int]) -> list[int]:
65+
def gcd_single(x_i, y_i):
66+
if x_i == None or y_i == None:
67+
return None
68+
while y_i != 0:
69+
(x_i, y_i) = (y_i, x_i % y_i)
70+
return x_i
71+
72+
return [gcd_single(x_i, y_i) for x_i, y_i in zip(x, y)]
73+
74+
5875
@udf(input_types=["VARCHAR", "VARCHAR", "VARCHAR"], result_type="VARCHAR")
5976
def split_and_join(s: str, split_s: str, join_s: str) -> str:
6077
return join_s.join(s.split(split_s))
@@ -303,6 +320,7 @@ def wait_concurrent(x):
303320
udf_server.add_function(binary_reverse)
304321
udf_server.add_function(bool_select)
305322
udf_server.add_function(gcd)
323+
udf_server.add_function(gcd_batch)
306324
udf_server.add_function(split_and_join)
307325
udf_server.add_function(decimal_div)
308326
udf_server.add_function(hex_to_dec)

0 commit comments

Comments
 (0)