Skip to content

feat: add expression array_size #1122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
Merge branch 'main' of github.com:Groennbeck/datafusion-comet into ar…
…ray-size
  • Loading branch information
Groennbeck committed Feb 10, 2025
commit b078747b08c061a45ab4ab07f7c303b2433a3208
84 changes: 84 additions & 0 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,90 @@ impl PhysicalPlanner {
)?;
Ok(Arc::new(ArraySize::new(child)))
}
ExprStruct::ArrayRemove(expr) => {
let src_array_expr =
self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
let key_expr =
self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
let args = vec![Arc::clone(&src_array_expr), Arc::clone(&key_expr)];
let return_type = src_array_expr.data_type(&input_schema)?;

let datafusion_array_remove = array_remove_all_udf();

let array_remove_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
"array_remove",
datafusion_array_remove,
args,
return_type,
));
let is_null_expr: Arc<dyn PhysicalExpr> = Arc::new(IsNullExpr::new(key_expr));

let null_literal_expr: Arc<dyn PhysicalExpr> =
Arc::new(Literal::new(ScalarValue::Null));

create_case_expr(
vec![(is_null_expr, null_literal_expr)],
Some(array_remove_expr),
&input_schema,
)
}
ExprStruct::ArrayIntersect(expr) => {
let left_expr =
self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
let right_expr =
self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
let args = vec![Arc::clone(&left_expr), right_expr];
let datafusion_array_intersect = array_intersect_udf();
let return_type = left_expr.data_type(&input_schema)?;
let array_intersect_expr = Arc::new(ScalarFunctionExpr::new(
"array_intersect",
datafusion_array_intersect,
args,
return_type,
));
Ok(array_intersect_expr)
}
ExprStruct::ArrayJoin(expr) => {
let array_expr =
self.create_expr(expr.array_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
let delimiter_expr = self.create_expr(
expr.delimiter_expr.as_ref().unwrap(),
Arc::clone(&input_schema),
)?;

let mut args = vec![Arc::clone(&array_expr), delimiter_expr];
if expr.null_replacement_expr.is_some() {
let null_replacement_expr = self.create_expr(
expr.null_replacement_expr.as_ref().unwrap(),
Arc::clone(&input_schema),
)?;
args.push(null_replacement_expr)
}

let datafusion_array_to_string = array_to_string_udf();
let array_join_expr = Arc::new(ScalarFunctionExpr::new(
"array_join",
datafusion_array_to_string,
args,
DataType::Utf8,
));
Ok(array_join_expr)
}
ExprStruct::ArraysOverlap(expr) => {
let left_array_expr =
self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
let right_array_expr =
self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
let args = vec![Arc::clone(&left_array_expr), right_array_expr];
let datafusion_array_has_any = array_has_any_udf();
let array_has_any_expr = Arc::new(ScalarFunctionExpr::new(
"array_has_any",
datafusion_array_has_any,
args,
DataType::Boolean,
));
Ok(array_has_any_expr)
}
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
Expand Down
12 changes: 11 additions & 1 deletion native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ message Expr {
BinaryExpr array_append = 58;
ArrayInsert array_insert = 59;
BinaryExpr array_contains = 60;
ArraySize array_size = 61;
BinaryExpr array_remove = 61;
BinaryExpr array_intersect = 62;
ArrayJoin array_join = 63;
BinaryExpr arrays_overlap = 64;
ArraySize array_size = 65;
}
}

Expand Down Expand Up @@ -418,6 +422,12 @@ message ArraySize {
Expr src_array_expr = 1;
}

message ArrayJoin {
Expr array_expr = 1;
Expr delimiter_expr = 2;
Expr null_replacement_expr = 3;
}

message DataType {
enum DataTypeId {
BOOL = 0;
Expand Down
84 changes: 0 additions & 84 deletions native/spark-expr/src/array_funcs/array_insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,90 +317,6 @@ impl Display for ArrayInsert {
}
}

impl PartialEq for ArraySize {
fn eq(&self, other: &Self) -> bool {
self.src_array_expr.eq(&other.src_array_expr)
}
}


#[derive(Debug, Eq)]
pub struct ArraySize {
src_array_expr: Arc<dyn PhysicalExpr>,
}

impl ArraySize {
pub fn new(src_array_expr: Arc<dyn PhysicalExpr>) -> Self {
Self { src_array_expr }
}
}

impl Display for ArraySize {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "ArraySize [array: {:?}]", self.src_array_expr)
}
}

impl Hash for ArraySize {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.src_array_expr.hash(state);
}
}

impl PhysicalExpr for ArraySize {
fn as_any(&self) -> &dyn Any {
self
}

fn data_type(&self, _input_schema: &Schema) -> DataFusionResult<DataType> {
Ok(DataType::Int32)
}

fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
self.src_array_expr.nullable(input_schema)
}

fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
let array_value = self
.src_array_expr
.evaluate(batch)?
.into_array(batch.num_rows())?;
match array_value.data_type() {
DataType::List(_) => {
let list_array = as_list_array(&array_value)?;
let mut builder = Int32Array::builder(list_array.len());
for i in 0..list_array.len() {
if list_array.is_null(i) {
builder.append_null();
} else {
builder.append_value(list_array.value_length(i));
}
}
let sizes_array = Int32Array::from(builder.finish());
Ok(ColumnarValue::Array(Arc::new(sizes_array)))
}
_ => Err(DataFusionError::Internal(format!(
"Unexpected data type in ArraySize: {:?}",
array_value.data_type()
))),
}
}

fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
vec![&self.src_array_expr]
}

fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
match children.len() {
1 => Ok(Arc::new(ArraySize::new(Arc::clone(&children[0])))),
_ => internal_err!("ArraySize should have exactly one child"),
}
}
}

#[cfg(test)]
mod test {
use super::*;
Expand Down
95 changes: 95 additions & 0 deletions native/spark-expr/src/array_funcs/array_size.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use std::any::Any;
use std::fmt::{Display, Formatter};
use std::hash::Hash;
use std::sync::Arc;
use arrow_array::{Array, Int32Array, RecordBatch};
use arrow_schema::{DataType, Schema};
use datafusion::physical_expr::PhysicalExpr;
use datafusion_common::cast::as_list_array;
use datafusion_common::{internal_err, DataFusionError, Result as DataFusionResult};
use datafusion_expr_common::columnar_value::ColumnarValue;



#[derive(Debug, Eq)]
pub struct ArraySize {
src_array_expr: Arc<dyn PhysicalExpr>,
}

impl PartialEq for ArraySize {
fn eq(&self, other: &Self) -> bool {
self.src_array_expr.eq(&other.src_array_expr)
}
}

impl ArraySize {
pub fn new(src_array_expr: Arc<dyn PhysicalExpr>) -> Self {
Self { src_array_expr }
}
}

impl Display for ArraySize {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "ArraySize [array: {:?}]", self.src_array_expr)
}
}

impl Hash for ArraySize {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.src_array_expr.hash(state);
}
}

impl PhysicalExpr for ArraySize {
fn as_any(&self) -> &dyn Any {
self
}

fn data_type(&self, _input_schema: &Schema) -> DataFusionResult<DataType> {
Ok(DataType::Int32)
}

fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
self.src_array_expr.nullable(input_schema)
}

fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
let array_value = self
.src_array_expr
.evaluate(batch)?
.into_array(batch.num_rows())?;
match array_value.data_type() {
DataType::List(_) => {
let list_array = as_list_array(&array_value)?;
let mut builder = Int32Array::builder(list_array.len());
for i in 0..list_array.len() {
if list_array.is_null(i) {
builder.append_null();
} else {
builder.append_value(list_array.value_length(i));
}
}
let sizes_array = Int32Array::from(builder.finish());
Ok(ColumnarValue::Array(Arc::new(sizes_array)))
}
_ => Err(DataFusionError::Internal(format!(
"Unexpected data type in ArraySize: {:?}",
array_value.data_type()
))),
}
}

fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
vec![&self.src_array_expr]
}

fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
match children.len() {
1 => Ok(Arc::new(ArraySize::new(Arc::clone(&children[0])))),
_ => internal_err!("ArraySize should have exactly one child"),
}
}
}
1 change: 1 addition & 0 deletions native/spark-expr/src/array_funcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
mod array_insert;
mod get_array_struct_fields;
mod list_extract;
mod array_size;

pub use array_insert::ArrayInsert;
pub use get_array_struct_fields::GetArrayStructFields;
Expand Down
8 changes: 1 addition & 7 deletions native/spark-expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,7 @@ mod agg_funcs;
mod array_funcs;
mod bitwise_funcs;
mod comet_scalar_funcs;
pub use cast::{spark_cast, Cast, SparkCastOptions};
pub use comet_scalar_funcs::create_comet_physical_fun;
pub use error::{SparkError, SparkResult};
pub use if_expr::IfExpr;
pub use list::{ArrayInsert, ArraySize, GetArrayStructFields, ListExtract};
pub use regexp::RLike;
pub use struct_funcs::*;
pub mod hash_funcs;

mod string_funcs;

Expand Down
Loading
You are viewing a condensed version of this merge commit. You can view the full changes here.