@@ -8,6 +8,7 @@ mod dictionary_to;
8
8
mod primitive_to;
9
9
mod utf8_to;
10
10
11
+ use arrow:: bitmap:: MutableBitmap ;
11
12
pub use binary_to:: * ;
12
13
#[ cfg( feature = "dtype-decimal" ) ]
13
14
pub use binview_to:: binview_to_decimal;
@@ -237,6 +238,127 @@ pub(super) fn cast_list_to_fixed_size_list<O: Offset>(
237
238
. map_err ( |_| polars_err ! ( ComputeError : "not all elements have the specified width {size}" ) )
238
239
}
239
240
241
+ fn cast_list_uint8_to_binary < O : Offset > ( list : & ListArray < O > ) -> PolarsResult < BinaryViewArray > {
242
+ let mut views = Vec :: with_capacity ( list. len ( ) ) ;
243
+ let mut result_validity = MutableBitmap :: from_len_set ( list. len ( ) ) ;
244
+
245
+ let u8array: & PrimitiveArray < u8 > = list. values ( ) . as_any ( ) . downcast_ref ( ) . unwrap ( ) ;
246
+ let slice = u8array. values ( ) . as_slice ( ) ;
247
+ let mut cloned_buffers = vec ! [ u8array. values( ) . clone( ) ] ;
248
+ let mut buf_index = 0 ;
249
+ let mut previous_buf_lengths = 0 ;
250
+ let validity = list. validity ( ) ;
251
+ let internal_validity = list. values ( ) . validity ( ) ;
252
+ let offsets = list. offsets ( ) ;
253
+
254
+ let mut all_views_inline = true ;
255
+
256
+ // In a View for BinaryViewArray, both length and offset are u32.
257
+ #[ cfg( not( test) ) ]
258
+ const MAX_BUF_SIZE : usize = u32:: MAX as usize ;
259
+
260
+ // This allows us to test some invariants without using 4GB of RAM; see mod
261
+ // tests below.
262
+ #[ cfg( test) ]
263
+ const MAX_BUF_SIZE : usize = 15 ;
264
+
265
+ for index in 0 ..list. len ( ) {
266
+ // Check if there's a null instead of a list:
267
+ if let Some ( validity) = validity {
268
+ // SAFETY: We are generating indexes limited to < list.len().
269
+ debug_assert ! ( index < validity. len( ) ) ;
270
+ if unsafe { !validity. get_bit_unchecked ( index) } {
271
+ debug_assert ! ( index < result_validity. len( ) ) ;
272
+ unsafe {
273
+ result_validity. set_unchecked ( index, false ) ;
274
+ }
275
+ views. push ( View :: default ( ) ) ;
276
+ continue ;
277
+ }
278
+ }
279
+
280
+ // SAFETY: We are generating indexes limited to < list.len().
281
+ debug_assert ! ( index < offsets. len( ) ) ;
282
+ let ( start, end) = unsafe { offsets. start_end_unchecked ( index) } ;
283
+ let length = end - start;
284
+ polars_ensure ! (
285
+ length <= MAX_BUF_SIZE ,
286
+ InvalidOperation : format!( "when casting to BinaryView, list lengths must be <= {MAX_BUF_SIZE}" )
287
+ ) ;
288
+
289
+ // Check if the list contains nulls:
290
+ if let Some ( internal_validity) = internal_validity {
291
+ if internal_validity. null_count_range ( start, length) > 0 {
292
+ debug_assert ! ( index < result_validity. len( ) ) ;
293
+ unsafe {
294
+ result_validity. set_unchecked ( index, false ) ;
295
+ }
296
+ views. push ( View :: default ( ) ) ;
297
+ continue ;
298
+ }
299
+ }
300
+
301
+ if end - previous_buf_lengths > MAX_BUF_SIZE {
302
+ // View offsets must fit in u32 (or smaller value when running Rust
303
+ // tests), and we've determined the end of the next view will be
304
+ // past that.
305
+ buf_index += 1 ;
306
+ let ( previous, next) = cloned_buffers
307
+ . last ( )
308
+ . unwrap ( )
309
+ . split_at ( start - previous_buf_lengths) ;
310
+ debug_assert ! ( previous. len( ) <= MAX_BUF_SIZE ) ;
311
+ previous_buf_lengths += previous. len ( ) ;
312
+ * ( cloned_buffers. last_mut ( ) . unwrap ( ) ) = previous;
313
+ cloned_buffers. push ( next) ;
314
+ }
315
+ let view = View :: new_from_bytes (
316
+ & slice[ start..end] ,
317
+ buf_index,
318
+ ( start - previous_buf_lengths) as u32 ,
319
+ ) ;
320
+ if !view. is_inline ( ) {
321
+ all_views_inline = false ;
322
+ }
323
+ debug_assert_eq ! (
324
+ unsafe { view. get_slice_unchecked( & cloned_buffers) } ,
325
+ & slice[ start..end]
326
+ ) ;
327
+ views. push ( view) ;
328
+ }
329
+
330
+ // Optimization: don't actually need buffers if Views are all inline.
331
+ if all_views_inline {
332
+ cloned_buffers. clear ( ) ;
333
+ }
334
+
335
+ let result_buffers = cloned_buffers. into_boxed_slice ( ) . into ( ) ;
336
+ let result = if cfg ! ( debug_assertions) {
337
+ // A safer wrapper around new_unchecked_unknown_md; it shouldn't ever
338
+ // fail in practice.
339
+ BinaryViewArrayGeneric :: try_new (
340
+ ArrowDataType :: BinaryView ,
341
+ views. into ( ) ,
342
+ result_buffers,
343
+ result_validity. into ( ) ,
344
+ ) ?
345
+ } else {
346
+ unsafe {
347
+ BinaryViewArrayGeneric :: new_unchecked_unknown_md (
348
+ ArrowDataType :: BinaryView ,
349
+ views. into ( ) ,
350
+ result_buffers,
351
+ result_validity. into ( ) ,
352
+ // We could compute this ourselves, but we want to make this code
353
+ // match debug_assertions path as much as possible.
354
+ None ,
355
+ )
356
+ }
357
+ } ;
358
+
359
+ Ok ( result)
360
+ }
361
+
240
362
pub fn cast_default ( array : & dyn Array , to_type : & ArrowDataType ) -> PolarsResult < Box < dyn Array > > {
241
363
cast ( array, to_type, Default :: default ( ) )
242
364
}
@@ -258,6 +380,7 @@ pub fn cast_unchecked(array: &dyn Array, to_type: &ArrowDataType) -> PolarsResul
258
380
/// * Fixed Size List to List: the underlying data type is cast
259
381
/// * List to Fixed Size List: the offsets are checked for valid order, then the
260
382
/// underlying type is cast.
383
+ /// * List of UInt8 to BinaryView: each list becomes one binary value; a list that is null, or that contains a null element, becomes null
261
384
/// * Struct to Struct: the underlying fields are cast.
262
385
/// * PrimitiveArray to List: a list array with 1 value per slot is created
263
386
/// * Date32 and Date64: precision lost when going to higher interval
@@ -267,7 +390,7 @@ pub fn cast_unchecked(array: &dyn Array, to_type: &ArrowDataType) -> PolarsResul
267
390
///
268
391
/// Unsupported Casts
269
392
/// * non-`StructArray` to `StructArray` or `StructArray` to non-`StructArray`
270
- /// * List to primitive
393
+ /// * List to primitive (other than UInt8)
271
394
/// * Utf8 to boolean
272
395
/// * Interval and duration
273
396
pub fn cast (
@@ -326,6 +449,14 @@ pub fn cast(
326
449
options,
327
450
)
328
451
. map ( |x| x. boxed ( ) ) ,
452
+ ( List ( field) , BinaryView ) if matches ! ( field. dtype( ) , UInt8 ) => {
453
+ cast_list_uint8_to_binary :: < i32 > ( array. as_any ( ) . downcast_ref ( ) . unwrap ( ) )
454
+ . map ( |arr| arr. boxed ( ) )
455
+ } ,
456
+ ( LargeList ( field) , BinaryView ) if matches ! ( field. dtype( ) , UInt8 ) => {
457
+ cast_list_uint8_to_binary :: < i64 > ( array. as_any ( ) . downcast_ref ( ) . unwrap ( ) )
458
+ . map ( |arr| arr. boxed ( ) )
459
+ } ,
329
460
( BinaryView , _) => match to_type {
330
461
Utf8View => array
331
462
. as_any ( )
@@ -853,3 +984,114 @@ fn from_to_binview(
853
984
} ;
854
985
Ok ( binview)
855
986
}
987
+
988
#[cfg(test)]
mod tests {
    use arrow::offset::OffsetsBuffer;
    use polars_error::PolarsError;

    use super::*;

    /// When cfg(test), a view produced by cast_list_uint8_to_binary() must end
    /// within the first 15 bytes of its buffer (MAX_BUF_SIZE), so buffers need
    /// to be split aggressively.
    #[test]
    fn cast_list_uint8_to_binary_across_buffer_max_size() {
        let dtype =
            ArrowDataType::List(Box::new(Field::new("".into(), ArrowDataType::UInt8, true)));
        let values = PrimitiveArray::from_slice((0u8..20).collect::<Vec<_>>()).boxed();
        let list_u8 = ListArray::try_new(
            dtype,
            unsafe { OffsetsBuffer::new_unchecked(vec![0, 13, 18, 20].into()) },
            values,
            None,
        )
        .unwrap();

        let binary = cast(
            &list_u8,
            &ArrowDataType::BinaryView,
            CastOptionsImpl::default(),
        )
        .unwrap();
        let binary_array: &BinaryViewArray = binary.as_ref().as_any().downcast_ref().unwrap();
        assert_eq!(
            binary_array
                .values_iter()
                .map(|s| s.to_vec())
                .collect::<Vec<Vec<u8>>>(),
            vec![
                vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
                vec![13, 14, 15, 16, 17],
                vec![18, 19]
            ]
        );
        // The second list would end at byte 18 of the first buffer, past the
        // limit of 15, so a second buffer starts at byte 13:
        assert_eq!(
            binary_array
                .data_buffers()
                .iter()
                .map(|buf| buf.len())
                .collect::<Vec<_>>(),
            vec![13, 7]
        );
    }

    /// Arrow spec requires views to fit in a single buffer. When cfg(test),
    /// buffers generated by cast_list_uint8_to_binary are of size 15 or
    /// smaller, so a list of size 16 should cause an error.
    #[test]
    fn cast_list_uint8_to_binary_errors_too_large_list() {
        let values = PrimitiveArray::from_slice(vec![0u8; 16]);
        let dtype =
            ArrowDataType::List(Box::new(Field::new("".into(), ArrowDataType::UInt8, true)));
        let list_u8 = ListArray::new(
            dtype,
            OffsetsBuffer::one_with_length(16),
            values.boxed(),
            None,
        );

        let err = cast(
            &list_u8,
            &ArrowDataType::BinaryView,
            CastOptionsImpl::default(),
        )
        .unwrap_err();
        assert!(matches!(
            err,
            PolarsError::InvalidOperation(msg)
                if msg.as_ref() == "when casting to BinaryView, list lengths must be <= 15"
        ));
    }

    /// When all views are <=12 bytes long, cast_list_uint8_to_binary drops the
    /// buffers in the result because all views are inline.
    #[test]
    fn cast_list_uint8_to_binary_drops_small_buffers() {
        let values = PrimitiveArray::from_slice(vec![10u8; 12]);
        let dtype =
            ArrowDataType::List(Box::new(Field::new("".into(), ArrowDataType::UInt8, true)));
        let list_u8 = ListArray::new(
            dtype,
            OffsetsBuffer::one_with_length(12),
            values.boxed(),
            None,
        );
        let binary = cast(
            &list_u8,
            &ArrowDataType::BinaryView,
            CastOptionsImpl::default(),
        )
        .unwrap();
        let binary_array: &BinaryViewArray = binary.as_ref().as_any().downcast_ref().unwrap();
        assert!(binary_array.data_buffers().is_empty());
        assert_eq!(
            binary_array
                .values_iter()
                .map(|s| s.to_vec())
                .collect::<Vec<Vec<u8>>>(),
            vec![vec![10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],]
        );
    }

    /// A null list, or a list containing a null element, is cast to a null
    /// binary value; the remaining lists are cast normally.
    #[test]
    fn cast_list_uint8_to_binary_propagates_nulls() {
        let values =
            PrimitiveArray::from(vec![Some(1u8), Some(2), None, Some(4), Some(5), Some(6)]);
        let dtype =
            ArrowDataType::List(Box::new(Field::new("".into(), ArrowDataType::UInt8, true)));
        // Four lists: [1, 2], [null, 4], <null list>, [6].
        let mut list_validity = MutableBitmap::from_len_set(4);
        list_validity.set(2, false);
        let list_u8 = ListArray::try_new(
            dtype,
            unsafe { OffsetsBuffer::new_unchecked(vec![0i32, 2, 4, 5, 6].into()) },
            values.boxed(),
            list_validity.into(),
        )
        .unwrap();

        let binary = cast(
            &list_u8,
            &ArrowDataType::BinaryView,
            CastOptionsImpl::default(),
        )
        .unwrap();
        let binary_array: &BinaryViewArray = binary.as_ref().as_any().downcast_ref().unwrap();
        assert_eq!(
            binary_array
                .iter()
                .map(|value| value.map(|bytes| bytes.to_vec()))
                .collect::<Vec<Option<Vec<u8>>>>(),
            vec![Some(vec![1u8, 2]), None, None, Some(vec![6u8])]
        );
    }
}
0 commit comments