Skip to content

Commit 6ee3df0

Browse files
feat(python): Allow casting List<UInt8> to Binary (#22611)
Co-authored-by: Itamar Turner-Trauring <[email protected]>
1 parent fcdf7f2 commit 6ee3df0

File tree

3 files changed

+310
-1
lines changed

3 files changed

+310
-1
lines changed

crates/polars-compute/src/cast/mod.rs

+243-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod dictionary_to;
88
mod primitive_to;
99
mod utf8_to;
1010

11+
use arrow::bitmap::MutableBitmap;
1112
pub use binary_to::*;
1213
#[cfg(feature = "dtype-decimal")]
1314
pub use binview_to::binview_to_decimal;
@@ -237,6 +238,127 @@ pub(super) fn cast_list_to_fixed_size_list<O: Offset>(
237238
.map_err(|_| polars_err!(ComputeError: "not all elements have the specified width {size}"))
238239
}
239240

241+
fn cast_list_uint8_to_binary<O: Offset>(list: &ListArray<O>) -> PolarsResult<BinaryViewArray> {
242+
let mut views = Vec::with_capacity(list.len());
243+
let mut result_validity = MutableBitmap::from_len_set(list.len());
244+
245+
let u8array: &PrimitiveArray<u8> = list.values().as_any().downcast_ref().unwrap();
246+
let slice = u8array.values().as_slice();
247+
let mut cloned_buffers = vec![u8array.values().clone()];
248+
let mut buf_index = 0;
249+
let mut previous_buf_lengths = 0;
250+
let validity = list.validity();
251+
let internal_validity = list.values().validity();
252+
let offsets = list.offsets();
253+
254+
let mut all_views_inline = true;
255+
256+
// In a View for BinaryViewArray, both length and offset are u32.
257+
#[cfg(not(test))]
258+
const MAX_BUF_SIZE: usize = u32::MAX as usize;
259+
260+
// This allows us to test some invariants without using 4GB of RAM; see mod
261+
// tests below.
262+
#[cfg(test)]
263+
const MAX_BUF_SIZE: usize = 15;
264+
265+
for index in 0..list.len() {
266+
// Check if there's a null instead of a list:
267+
if let Some(validity) = validity {
268+
// SAFETY: We are generating indexes limited to < list.len().
269+
debug_assert!(index < validity.len());
270+
if unsafe { !validity.get_bit_unchecked(index) } {
271+
debug_assert!(index < result_validity.len());
272+
unsafe {
273+
result_validity.set_unchecked(index, false);
274+
}
275+
views.push(View::default());
276+
continue;
277+
}
278+
}
279+
280+
// SAFETY: We are generating indexes limited to < list.len().
281+
debug_assert!(index < offsets.len());
282+
let (start, end) = unsafe { offsets.start_end_unchecked(index) };
283+
let length = end - start;
284+
polars_ensure!(
285+
length <= MAX_BUF_SIZE,
286+
InvalidOperation: format!("when casting to BinaryView, list lengths must be <= {MAX_BUF_SIZE}")
287+
);
288+
289+
// Check if the list contains nulls:
290+
if let Some(internal_validity) = internal_validity {
291+
if internal_validity.null_count_range(start, length) > 0 {
292+
debug_assert!(index < result_validity.len());
293+
unsafe {
294+
result_validity.set_unchecked(index, false);
295+
}
296+
views.push(View::default());
297+
continue;
298+
}
299+
}
300+
301+
if end - previous_buf_lengths > MAX_BUF_SIZE {
302+
// View offsets must fit in u32 (or smaller value when running Rust
303+
// tests), and we've determined the end of the next view will be
304+
// past that.
305+
buf_index += 1;
306+
let (previous, next) = cloned_buffers
307+
.last()
308+
.unwrap()
309+
.split_at(start - previous_buf_lengths);
310+
debug_assert!(previous.len() <= MAX_BUF_SIZE);
311+
previous_buf_lengths += previous.len();
312+
*(cloned_buffers.last_mut().unwrap()) = previous;
313+
cloned_buffers.push(next);
314+
}
315+
let view = View::new_from_bytes(
316+
&slice[start..end],
317+
buf_index,
318+
(start - previous_buf_lengths) as u32,
319+
);
320+
if !view.is_inline() {
321+
all_views_inline = false;
322+
}
323+
debug_assert_eq!(
324+
unsafe { view.get_slice_unchecked(&cloned_buffers) },
325+
&slice[start..end]
326+
);
327+
views.push(view);
328+
}
329+
330+
// Optimization: don't actually need buffers if Views are all inline.
331+
if all_views_inline {
332+
cloned_buffers.clear();
333+
}
334+
335+
let result_buffers = cloned_buffers.into_boxed_slice().into();
336+
let result = if cfg!(debug_assertions) {
337+
// A safer wrapper around new_unchecked_unknown_md; it shouldn't ever
338+
// fail in practice.
339+
BinaryViewArrayGeneric::try_new(
340+
ArrowDataType::BinaryView,
341+
views.into(),
342+
result_buffers,
343+
result_validity.into(),
344+
)?
345+
} else {
346+
unsafe {
347+
BinaryViewArrayGeneric::new_unchecked_unknown_md(
348+
ArrowDataType::BinaryView,
349+
views.into(),
350+
result_buffers,
351+
result_validity.into(),
352+
// We could compute this ourselves, but we want to make this code
353+
// match debug_assertions path as much as possible.
354+
None,
355+
)
356+
}
357+
};
358+
359+
Ok(result)
360+
}
361+
240362
pub fn cast_default(array: &dyn Array, to_type: &ArrowDataType) -> PolarsResult<Box<dyn Array>> {
241363
cast(array, to_type, Default::default())
242364
}
@@ -258,6 +380,7 @@ pub fn cast_unchecked(array: &dyn Array, to_type: &ArrowDataType) -> PolarsResul
258380
/// * Fixed Size List to List: the underlying data type is cast
259381
/// * List to Fixed Size List: the offsets are checked for valid order, then the
260382
/// underlying type is cast.
383+
/// * List of UInt8 to Binary: the list of integers becomes binary data; a list that contains a null becomes a null
261384
/// * Struct to Struct: the underlying fields are cast.
262385
/// * PrimitiveArray to List: a list array with 1 value per slot is created
263386
/// * Date32 and Date64: precision lost when going to higher interval
@@ -267,7 +390,7 @@ pub fn cast_unchecked(array: &dyn Array, to_type: &ArrowDataType) -> PolarsResul
267390
///
268391
/// Unsupported Casts
269392
/// * non-`StructArray` to `StructArray` or `StructArray` to non-`StructArray`
270-
/// * List to primitive
393+
/// * List to primitive (the one exception is List of UInt8 to Binary)
271394
/// * Utf8 to boolean
272395
/// * Interval and duration
273396
pub fn cast(
@@ -326,6 +449,14 @@ pub fn cast(
326449
options,
327450
)
328451
.map(|x| x.boxed()),
452+
(List(field), BinaryView) if matches!(field.dtype(), UInt8) => {
453+
cast_list_uint8_to_binary::<i32>(array.as_any().downcast_ref().unwrap())
454+
.map(|arr| arr.boxed())
455+
},
456+
(LargeList(field), BinaryView) if matches!(field.dtype(), UInt8) => {
457+
cast_list_uint8_to_binary::<i64>(array.as_any().downcast_ref().unwrap())
458+
.map(|arr| arr.boxed())
459+
},
329460
(BinaryView, _) => match to_type {
330461
Utf8View => array
331462
.as_any()
@@ -853,3 +984,114 @@ fn from_to_binview(
853984
};
854985
Ok(binview)
855986
}
987+
988+
#[cfg(test)]
989+
mod tests {
990+
use arrow::offset::OffsetsBuffer;
991+
use polars_error::PolarsError;
992+
993+
use super::*;
994+
995+
/// When cfg(test), offsets for ``View``s generated by
996+
/// cast_list_uint8_to_binary() are limited to max value of 3, so buffers
997+
/// need to be split aggressively.
998+
#[test]
999+
fn cast_list_uint8_to_binary_across_buffer_max_size() {
1000+
let dtype =
1001+
ArrowDataType::List(Box::new(Field::new("".into(), ArrowDataType::UInt8, true)));
1002+
let values = PrimitiveArray::from_slice((0u8..20).collect::<Vec<_>>()).boxed();
1003+
let list_u8 = ListArray::try_new(
1004+
dtype,
1005+
unsafe { OffsetsBuffer::new_unchecked(vec![0, 13, 18, 20].into()) },
1006+
values,
1007+
None,
1008+
)
1009+
.unwrap();
1010+
1011+
let binary = cast(
1012+
&list_u8,
1013+
&ArrowDataType::BinaryView,
1014+
CastOptionsImpl::default(),
1015+
)
1016+
.unwrap();
1017+
let binary_array: &BinaryViewArray = binary.as_ref().as_any().downcast_ref().unwrap();
1018+
assert_eq!(
1019+
binary_array
1020+
.values_iter()
1021+
.map(|s| s.to_vec())
1022+
.collect::<Vec<Vec<u8>>>(),
1023+
vec![
1024+
vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
1025+
vec![13, 14, 15, 16, 17],
1026+
vec![18, 19]
1027+
]
1028+
);
1029+
// max offset of 15 so we need to split:
1030+
assert_eq!(
1031+
binary_array
1032+
.data_buffers()
1033+
.iter()
1034+
.map(|buf| buf.len())
1035+
.collect::<Vec<_>>(),
1036+
vec![13, 7]
1037+
);
1038+
}
1039+
1040+
/// Arrow spec requires views to fit in a single buffer. When cfg(test),
1041+
/// buffers generated by cast_list_uint8_to_binary are of size 15 or
1042+
/// smaller, so a list of size 16 should cause an error.
1043+
#[test]
1044+
fn cast_list_uint8_to_binary_errors_too_large_list() {
1045+
let values = PrimitiveArray::from_slice(vec![0u8; 16]);
1046+
let dtype =
1047+
ArrowDataType::List(Box::new(Field::new("".into(), ArrowDataType::UInt8, true)));
1048+
let list_u8 = ListArray::new(
1049+
dtype,
1050+
OffsetsBuffer::one_with_length(16),
1051+
values.boxed(),
1052+
None,
1053+
);
1054+
1055+
let err = cast(
1056+
&list_u8,
1057+
&ArrowDataType::BinaryView,
1058+
CastOptionsImpl::default(),
1059+
)
1060+
.unwrap_err();
1061+
assert!(matches!(
1062+
err,
1063+
PolarsError::InvalidOperation(msg)
1064+
if msg.as_ref() == "when casting to BinaryView, list lengths must be <= 15"
1065+
));
1066+
}
1067+
1068+
/// When all views are <=12 bytes long, cast_list_uint8_to_binary drops buffers in the
1069+
/// result because all views are inline.
1070+
#[test]
1071+
fn cast_list_uint8_to_binary_drops_small_buffers() {
1072+
let values = PrimitiveArray::from_slice(vec![10u8; 12]);
1073+
let dtype =
1074+
ArrowDataType::List(Box::new(Field::new("".into(), ArrowDataType::UInt8, true)));
1075+
let list_u8 = ListArray::new(
1076+
dtype,
1077+
OffsetsBuffer::one_with_length(12),
1078+
values.boxed(),
1079+
None,
1080+
);
1081+
let binary = cast(
1082+
&list_u8,
1083+
&ArrowDataType::BinaryView,
1084+
CastOptionsImpl::default(),
1085+
)
1086+
.unwrap();
1087+
let binary_array: &BinaryViewArray = binary.as_ref().as_any().downcast_ref().unwrap();
1088+
assert!(binary_array.data_buffers().is_empty());
1089+
assert_eq!(
1090+
binary_array
1091+
.values_iter()
1092+
.map(|s| s.to_vec())
1093+
.collect::<Vec<Vec<u8>>>(),
1094+
vec![vec![10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],]
1095+
);
1096+
}
1097+
}

crates/polars-core/src/chunked_array/cast.rs

+19
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,25 @@ impl ChunkCast for ListChunked {
513513
))
514514
}
515515
},
516+
#[cfg(feature = "dtype-u8")]
517+
Binary => {
518+
polars_ensure!(
519+
matches!(self.inner_dtype(), UInt8),
520+
InvalidOperation: "cannot cast List type (inner: '{:?}', to: '{:?}')",
521+
self.inner_dtype(),
522+
dtype,
523+
);
524+
let chunks = cast_chunks(self.chunks(), &DataType::Binary, options)?;
525+
526+
// SAFETY: we just cast so the dtype matches.
527+
unsafe {
528+
Ok(Series::from_chunks_and_dtype_unchecked(
529+
self.name().clone(),
530+
chunks,
531+
&DataType::Binary,
532+
))
533+
}
534+
},
516535
_ => {
517536
polars_bail!(
518537
InvalidOperation: "cannot cast List type (inner: '{:?}', to: '{:?}')",

py-polars/tests/unit/operations/test_cast.py

+48
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,54 @@ def test_invalid_inner_type_cast_list() -> None:
656656
s.cast(pl.List(pl.Categorical))
657657

658658

659+
@pytest.mark.parametrize(
660+
("values", "result"),
661+
[
662+
([[]], [b""]),
663+
([[1, 2], [3, 4]], [b"\x01\x02", b"\x03\x04"]),
664+
([[1, 2], None, [3, 4]], [b"\x01\x02", None, b"\x03\x04"]),
665+
(
666+
[None, [111, 110, 101], [12, None], [116, 119, 111], list(range(256))],
667+
[
668+
None,
669+
b"one",
670+
# A list with a null in it gets turned into a null:
671+
None,
672+
b"two",
673+
bytes(i for i in range(256)),
674+
],
675+
),
676+
],
677+
)
678+
def test_list_uint8_to_bytes(
679+
values: list[list[int | None] | None], result: list[bytes | None]
680+
) -> None:
681+
s = pl.Series(
682+
values,
683+
dtype=pl.List(pl.UInt8()),
684+
)
685+
assert s.cast(pl.Binary(), strict=False).to_list() == result
686+
687+
688+
def test_list_uint8_to_bytes_strict() -> None:
689+
series = pl.Series(
690+
[[1, 2], [3, 4]],
691+
dtype=pl.List(pl.UInt8()),
692+
)
693+
assert series.cast(pl.Binary(), strict=True).to_list() == [b"\x01\x02", b"\x03\x04"]
694+
695+
series = pl.Series(
696+
"mycol",
697+
[[1, 2], [3, None]],
698+
dtype=pl.List(pl.UInt8()),
699+
)
700+
with pytest.raises(
701+
InvalidOperationError,
702+
match="conversion from `list\\[u8\\]` to `binary` failed in column 'mycol' for 1 out of 2 values: \\[\\[3, null\\]\\]",
703+
):
704+
series.cast(pl.Binary(), strict=True)
705+
706+
659707
def test_all_null_cast_5826() -> None:
660708
df = pl.DataFrame(data=[pl.Series("a", [None], dtype=pl.String)])
661709
out = df.with_columns(pl.col("a").cast(pl.Boolean))

0 commit comments

Comments
 (0)