@@ -59,9 +59,17 @@ class ScalarFunction(UserDefinedFunction):
59
59
_io_threads : Optional [int ]
60
60
_executor : Optional [ThreadPoolExecutor ]
61
61
_skip_null : bool
62
+ _batch_mode : bool
62
63
63
64
def __init__ (
64
- self , func , input_types , result_type , name = None , io_threads = None , skip_null = None
65
+ self ,
66
+ func ,
67
+ input_types ,
68
+ result_type ,
69
+ name = None ,
70
+ io_threads = None ,
71
+ skip_null = None ,
72
+ batch_mode = False ,
65
73
):
66
74
self ._func = func
67
75
self ._input_schema = pa .schema (
@@ -78,6 +86,7 @@ def __init__(
78
86
func .__name__ if hasattr (func , "__name__" ) else func .__class__ .__name__
79
87
)
80
88
self ._io_threads = io_threads
89
+ self ._batch_mode = batch_mode
81
90
self ._executor = (
82
91
ThreadPoolExecutor (max_workers = self ._io_threads )
83
92
if self ._io_threads is not None
@@ -98,7 +107,11 @@ def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
98
107
_input_process_func (_list_field (field ))(array )
99
108
for array , field in zip (inputs , self ._input_schema )
100
109
]
101
- if self ._executor is not None :
110
+
111
+ # evaluate the function for each row
112
+ if self ._batch_mode :
113
+ column = self ._func (* inputs )
114
+ elif self ._executor is not None :
102
115
# concurrently evaluate the function for each row
103
116
if self ._skip_null :
104
117
tasks = []
@@ -113,7 +126,6 @@ def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
113
126
]
114
127
column = [future .result () for future in tasks ]
115
128
else :
116
- # evaluate the function for each row
117
129
if self ._skip_null :
118
130
column = []
119
131
for row in range (batch .num_rows ):
@@ -140,6 +152,7 @@ def udf(
140
152
name : Optional [str ] = None ,
141
153
io_threads : Optional [int ] = 32 ,
142
154
skip_null : Optional [bool ] = False ,
155
+ batch_mode : Optional [bool ] = False ,
143
156
) -> Callable :
144
157
"""
145
158
Annotation for creating a user-defined scalar function.
@@ -153,6 +166,7 @@ def udf(
153
166
- skip_null: A boolean value specifying whether to skip NULL value. If it is set to True,
154
167
NULL values will not be passed to the function,
155
168
and the corresponding return value is set to NULL. Default to False.
169
+ - batch_mode: A boolean value specifying whether to use batch mode. Default to False.
156
170
157
171
Example:
158
172
```
@@ -170,6 +184,13 @@ def external_api(x):
170
184
response = requests.get(my_endpoint + '?param=' + x)
171
185
return response["data"]
172
186
```
187
+
188
+ Batch mode example:
189
+ ```
190
+ @udf(input_types=['INT', 'INT'], result_type='INT', batch_mode=True)
191
+ def gcd(x, y):
192
+ return [x_i if y_i == 0 else gcd(y_i, x_i % y_i) for x_i, y_i in zip(x, y)]
193
+ ```
173
194
"""
174
195
175
196
if io_threads is not None and io_threads > 1 :
@@ -180,10 +201,16 @@ def external_api(x):
180
201
name ,
181
202
io_threads = io_threads ,
182
203
skip_null = skip_null ,
204
+ batch_mode = batch_mode ,
183
205
)
184
206
else :
185
207
return lambda f : ScalarFunction (
186
- f , input_types , result_type , name , skip_null = skip_null
208
+ f ,
209
+ input_types ,
210
+ result_type ,
211
+ name ,
212
+ skip_null = skip_null ,
213
+ batch_mode = batch_mode ,
187
214
)
188
215
189
216
0 commit comments