Skip to content

feat: Allow casting List<UInt8> to Binary #22611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
May 8, 2025
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
887d297
High-level dispatch for casting List(Uint8) to Binary.
pythonspeed Apr 30, 2025
5af6f01
WIP Sketch of low-level Arrow casting from List(Uint8) to BinaryView.
pythonspeed Apr 30, 2025
a24059d
Merge remote-tracking branch 'origin/main' into 21549-cast-list-uint8…
pythonspeed May 5, 2025
e140208
Get List<u8> to Binary casting working.
pythonspeed May 5, 2025
7d28526
More testing.
pythonspeed May 5, 2025
960d106
Lints and reformatting.
pythonspeed May 5, 2025
eacbdd1
Look up values on creation, potentially optimizing code.
pythonspeed May 5, 2025
027dc4c
Merge remote-tracking branch 'origin/main' into 21549-cast-list-uint8…
pythonspeed May 6, 2025
6527918
Reuse the existing buffer.
pythonspeed May 6, 2025
8f3268a
Test another edge case
pythonspeed May 6, 2025
9f85e53
Handle buffers >= 2^32 bytes.
pythonspeed May 6, 2025
5c22159
Optimization: if all views are inline, we don't need to keep a refere…
pythonspeed May 6, 2025
0b02563
Remove optimization.
pythonspeed May 7, 2025
513fae7
Don't bother using MutableBinaryViewArray.
pythonspeed May 7, 2025
8de12f0
Catch lists that don't fit in single buffer to match Arrow spec.
pythonspeed May 7, 2025
1470da4
A better assertion.
pythonspeed May 7, 2025
1b7dd97
Optimization: don't need buffers if all Views are inline
pythonspeed May 7, 2025
817f958
Fix bug in splitting, improve tests, and add a test.
pythonspeed May 8, 2025
2a4dae3
Assert the data is still there!
pythonspeed May 8, 2025
674c82a
Pacify clippy.
pythonspeed May 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 207 additions & 1 deletion crates/polars-compute/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod dictionary_to;
mod primitive_to;
mod utf8_to;

use arrow::bitmap::MutableBitmap;
pub use binary_to::*;
#[cfg(feature = "dtype-decimal")]
pub use binview_to::binview_to_decimal;
Expand Down Expand Up @@ -237,6 +238,125 @@ pub(super) fn cast_list_to_fixed_size_list<O: Offset>(
.map_err(|_| polars_err!(ComputeError: "not all elements have the specified width {size}"))
}

/// Casts a list array whose values are `u8` into a [`BinaryViewArray`].
///
/// Every list becomes one binary value made of that list's bytes. An output
/// slot is null when the list itself is null or when any element of the list
/// is null.
///
/// The views of the result point into clones of the list's values buffer. A
/// `View` stores its length and its offset as `u32`, so whenever a view would
/// reach past `MAX_BUF_SIZE` bytes from the start of the current buffer, that
/// buffer is split right before the view and the view points at offset 0 of
/// the new buffer. This keeps `offset + length <= MAX_BUF_SIZE` for every
/// view.
///
/// # Errors
///
/// Returns an `InvalidOperation` error if a single non-null list holds more
/// than `MAX_BUF_SIZE` elements.
///
/// # Panics
///
/// Panics if the values of `list` are not a `PrimitiveArray<u8>`; callers are
/// expected to check the inner dtype before dispatching here.
fn cast_list_uint8_to_binary<O: Offset>(list: &ListArray<O>) -> PolarsResult<BinaryViewArray> {
    // In a View for BinaryViewArray, both length and offset are u32.
    #[cfg(not(test))]
    const MAX_BUF_SIZE: usize = u32::MAX as usize;

    // This allows us to test some invariants without using 4GB of RAM; see mod
    // tests below.
    #[cfg(test)]
    const MAX_BUF_SIZE: usize = 3;

    let mut views = Vec::with_capacity(list.len());
    let mut result_validity = MutableBitmap::from_len_set(list.len());

    let u8array: &PrimitiveArray<u8> = list.values().as_any().downcast_ref().unwrap();
    let slice = u8array.values().as_slice();
    // Start with the whole values buffer; it only gets split when a view would
    // otherwise reach too far into it.
    let mut cloned_buffers = vec![u8array.values().clone()];
    // Index (into `cloned_buffers`) of the buffer new views point at.
    let mut buf_index = 0;
    // Combined length of all buffers before the current one, i.e. the position
    // in `slice` at which the current buffer starts.
    let mut previous_buf_lengths = 0;
    let validity = list.validity();
    let internal_validity = list.values().validity();
    let offsets = list.offsets();

    let mut all_views_inline = true;

    for index in 0..list.len() {
        // Check if there's a null instead of a list:
        if let Some(validity) = validity {
            debug_assert!(index < validity.len());
            // SAFETY: We are generating indexes limited to < list.len().
            if unsafe { !validity.get_bit_unchecked(index) } {
                debug_assert!(index < result_validity.len());
                // SAFETY: `result_validity` was created with `list.len()` bits
                // and `index < list.len()`.
                unsafe {
                    result_validity.set_unchecked(index, false);
                }
                // Placeholder view for the null slot.
                views.push(View::default());
                continue;
            }
        }

        debug_assert!(index < offsets.len());
        // SAFETY: We are generating indexes limited to < list.len().
        let (start, end) = unsafe { offsets.start_end_unchecked(index) };
        let length = end - start;
        polars_ensure!(
            length <= MAX_BUF_SIZE,
            InvalidOperation: "when casting to BinaryView, list lengths must fit in u32"
        );

        // Check if the list contains nulls:
        if let Some(internal_validity) = internal_validity {
            if internal_validity.null_count_range(start, length) > 0 {
                debug_assert!(index < result_validity.len());
                // SAFETY: `result_validity` was created with `list.len()` bits
                // and `index < list.len()`.
                unsafe {
                    result_validity.set_unchecked(index, false);
                }
                // Placeholder view for the null slot.
                views.push(View::default());
                continue;
            }
        }

        if end - previous_buf_lengths > MAX_BUF_SIZE {
            // The end of this view would lie past MAX_BUF_SIZE (u32::MAX, or a
            // smaller value when running Rust tests) in the current buffer.
            // Split the buffer right before the view so that it starts at
            // offset 0 of a fresh buffer. Since `length <= MAX_BUF_SIZE`, the
            // split point is always > 0, so no empty buffer is created.
            buf_index += 1;
            let (previous, next) = cloned_buffers
                .last()
                .unwrap()
                .split_at(start - previous_buf_lengths);
            previous_buf_lengths += previous.len();
            *(cloned_buffers.last_mut().unwrap()) = previous;
            cloned_buffers.push(next);
        }
        // Guards the `as u32` cast below: offset <= offset + length <= MAX_BUF_SIZE.
        debug_assert!(end - previous_buf_lengths <= MAX_BUF_SIZE);
        let view = View::new_from_bytes(
            &slice[start..end],
            buf_index,
            (start - previous_buf_lengths) as u32,
        );
        if !view.is_inline() {
            all_views_inline = false;
        }
        debug_assert_eq!(
            // SAFETY: the view was just built for `cloned_buffers[buf_index]`
            // at an offset/length that lies inside that buffer.
            unsafe { view.get_slice_unchecked(&cloned_buffers) },
            &slice[start..end]
        );
        views.push(view);
    }

    // Optimization: don't actually need buffers if Views are all inline.
    if all_views_inline {
        cloned_buffers.clear();
    }

    let result_buffers = cloned_buffers.into_boxed_slice().into();
    let result = if cfg!(debug_assertions) {
        // A safer wrapper around new_unchecked_unknown_md; it shouldn't ever
        // fail in practice.
        BinaryViewArrayGeneric::try_new(
            ArrowDataType::BinaryView,
            views.into(),
            result_buffers,
            result_validity.into(),
        )?
    } else {
        // SAFETY: every non-inline view was created above from a byte range of
        // the buffer it refers to; debug builds validate the same construction
        // through `try_new`.
        unsafe {
            BinaryViewArrayGeneric::new_unchecked_unknown_md(
                ArrowDataType::BinaryView,
                views.into(),
                result_buffers,
                result_validity.into(),
                // We could compute this ourselves, but we want to make this code
                // match debug_assertions path as much as possible.
                None,
            )
        }
    };

    Ok(result)
}

/// Casts `array` to `to_type` using the default cast options.
///
/// Convenience wrapper that forwards to [`cast`].
pub fn cast_default(array: &dyn Array, to_type: &ArrowDataType) -> PolarsResult<Box<dyn Array>> {
    let default_options = Default::default();
    cast(array, to_type, default_options)
}
Expand All @@ -258,6 +378,7 @@ pub fn cast_unchecked(array: &dyn Array, to_type: &ArrowDataType) -> PolarsResul
/// * Fixed Size List to List: the underlying data type is cast
/// * List to Fixed Size List: the offsets are checked for valid order, then the
/// underlying type is cast.
/// * List of UInt8 to Binary: each list of integers becomes one binary value; a list that is null or contains a null becomes null
/// * Struct to Struct: the underlying fields are cast.
/// * PrimitiveArray to List: a list array with 1 value per slot is created
/// * Date32 and Date64: precision lost when going to higher interval
Expand All @@ -267,7 +388,7 @@ pub fn cast_unchecked(array: &dyn Array, to_type: &ArrowDataType) -> PolarsResul
///
/// Unsupported Casts
/// * non-`StructArray` to `StructArray` or `StructArray` to non-`StructArray`
/// * List to primitive
/// * List to primitive (note: List of UInt8 to Binary is supported, see above)
/// * Utf8 to boolean
/// * Interval and duration
pub fn cast(
Expand Down Expand Up @@ -326,6 +447,14 @@ pub fn cast(
options,
)
.map(|x| x.boxed()),
(List(field), BinaryView) if matches!(field.dtype(), UInt8) => {
cast_list_uint8_to_binary::<i32>(array.as_any().downcast_ref().unwrap())
.map(|arr| arr.boxed())
},
(LargeList(field), BinaryView) if matches!(field.dtype(), UInt8) => {
cast_list_uint8_to_binary::<i64>(array.as_any().downcast_ref().unwrap())
.map(|arr| arr.boxed())
},
(BinaryView, _) => match to_type {
Utf8View => array
.as_any()
Expand Down Expand Up @@ -853,3 +982,80 @@ fn from_to_binview(
};
Ok(binview)
}

#[cfg(test)]
mod tests {
    use arrow::offset::OffsetsBuffer;
    use polars_error::PolarsError;

    use super::*;

    /// When cfg(test), offsets for ``View``s generated by
    /// cast_list_uint8_to_binary() are limited to max value of 3, so buffers
    /// need to be split aggressively.
    #[test]
    fn cast_list_uint8_to_binary_across_buffer_max_size() {
        let dtype =
            ArrowDataType::List(Box::new(Field::new("".into(), ArrowDataType::UInt8, true)));
        let list_u8 = cast(
            // Values become individual lists:
            &PrimitiveArray::from_slice([1u8, 2, 3, 4, 5, 6, 7]),
            &dtype,
            CastOptionsImpl {
                wrapped: false,
                partial: true,
            },
        )
        .unwrap();

        let binary = cast(
            list_u8.as_ref(),
            &ArrowDataType::BinaryView,
            CastOptionsImpl::default(),
        )
        .unwrap();
        let binary_array: &BinaryViewArray = binary.as_ref().as_any().downcast_ref().unwrap();
        // The data must survive the cast regardless of how buffers are split.
        assert_eq!(
            binary_array.values_iter().collect::<Vec<&[u8]>>(),
            vec![&[1], &[2], &[3], &[4], &[5], &[6], &[7]]
        );
        // max offset of 3 so we need to split:
        //
        // NOTE(review): every value here is a single byte. If such short views
        // are stored inline, cast_list_uint8_to_binary() drops all buffers
        // ("don't actually need buffers if Views are all inline") and this
        // assertion would observe an empty list instead — confirm.
        assert_eq!(
            binary_array
                .data_buffers()
                .iter()
                .map(|buf| buf.len())
                .collect::<Vec<_>>(),
            vec![4, 3]
        );
    }

    /// Arrow spec requires views to fit in a single buffer. When cfg(test),
    /// buffers generated by cast_list_uint8_to_binary are of size 3 or smaller,
    /// so a list of size 4 should cause an error.
    #[test]
    fn cast_list_uint8_to_binary_errors_too_large_list() {
        let values = PrimitiveArray::from_slice([1u8, 2, 3, 4]);
        let dtype =
            ArrowDataType::List(Box::new(Field::new("".into(), ArrowDataType::UInt8, true)));
        // A single list covering all four values.
        let list_u8 = ListArray::new(
            dtype,
            OffsetsBuffer::one_with_length(4),
            values.boxed(),
            None,
        );

        let err = cast(
            &list_u8,
            &ArrowDataType::BinaryView,
            CastOptionsImpl::default(),
        )
        .unwrap_err();
        // Match exhaustively so that an unexpected error variant fails the
        // test instead of silently skipping the assertion.
        match err {
            PolarsError::InvalidOperation(msg) => assert_eq!(
                msg.as_ref(),
                "when casting to BinaryView, list lengths must fit in u32"
            ),
            other => panic!("expected an InvalidOperation error, got: {other:?}"),
        }
    }
}
19 changes: 19 additions & 0 deletions crates/polars-core/src/chunked_array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,25 @@ impl ChunkCast for ListChunked {
))
}
},
#[cfg(feature = "dtype-u8")]
Binary => {
polars_ensure!(
matches!(self.inner_dtype(), UInt8),
InvalidOperation: "cannot cast List type (inner: '{:?}', to: '{:?}')",
self.inner_dtype(),
dtype,
);
let chunks = cast_chunks(self.chunks(), &DataType::Binary, options)?;

// SAFETY: we just cast so the dtype matches.
unsafe {
Ok(Series::from_chunks_and_dtype_unchecked(
self.name().clone(),
chunks,
&DataType::Binary,
))
}
},
_ => {
polars_bail!(
InvalidOperation: "cannot cast List type (inner: '{:?}', to: '{:?}')",
Expand Down
48 changes: 48 additions & 0 deletions py-polars/tests/unit/operations/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,54 @@ def test_invalid_inner_type_cast_list() -> None:
s.cast(pl.List(pl.Categorical))


@pytest.mark.parametrize(
    ("values", "result"),
    [
        ([[]], [b""]),
        ([[1, 2], [3, 4]], [b"\x01\x02", b"\x03\x04"]),
        ([[1, 2], None, [3, 4]], [b"\x01\x02", None, b"\x03\x04"]),
        (
            [None, [111, 110, 101], [12, None], [116, 119, 111], list(range(256))],
            [
                None,
                b"one",
                # A list that contains a null is converted to a null:
                None,
                b"two",
                bytes(range(256)),
            ],
        ),
    ],
)
def test_list_uint8_to_bytes(
    values: list[list[int | None] | None], result: list[bytes | None]
) -> None:
    """Non-strict cast of List(UInt8) to Binary gives bytes per list, or null."""
    list_series = pl.Series(values, dtype=pl.List(pl.UInt8()))
    as_binary = list_series.cast(pl.Binary(), strict=False)
    assert as_binary.to_list() == result


def test_list_uint8_to_bytes_strict() -> None:
    """Strict cast works without nulls and raises when a list contains a null."""
    u8_list = pl.List(pl.UInt8())

    # No nulls anywhere: the strict cast succeeds.
    without_nulls = pl.Series([[1, 2], [3, 4]], dtype=u8_list)
    expected = [b"\x01\x02", b"\x03\x04"]
    assert without_nulls.cast(pl.Binary(), strict=True).to_list() == expected

    # A null inside a list cannot be converted, so the strict cast must fail.
    with_inner_null = pl.Series("mycol", [[1, 2], [3, None]], dtype=u8_list)
    with pytest.raises(
        InvalidOperationError,
        match="conversion from `list\\[u8\\]` to `binary` failed in column 'mycol' for 1 out of 2 values: \\[\\[3, null\\]\\]",
    ):
        with_inner_null.cast(pl.Binary(), strict=True)


def test_all_null_cast_5826() -> None:
df = pl.DataFrame(data=[pl.Series("a", [None], dtype=pl.String)])
out = df.with_columns(pl.col("a").cast(pl.Boolean))
Expand Down
Loading